1use apalis_core::request::Parts;
2use apalis_core::task::attempt::Attempt;
3use apalis_core::task::task_id::TaskId;
4use apalis_core::{request::Request, worker::WorkerId};
5
6use serde::{Deserialize, Serialize};
7use sqlx::{Decode, Type};
8
9use crate::context::SqlContext;
/// A job row as loaded from a SQL backing store: the decoded job payload
/// of type `T` together with its [`SqlContext`] metadata (status, attempts,
/// locks, timestamps) carried in the request's parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlRequest<T> {
    /// The reconstructed request: job payload plus SQL context parts.
    pub req: Request<T, SqlContext>,
    // Private marker field: forces construction through `SqlRequest::new`
    // or the backend `FromRow` impls below, keeping the type extensible.
    pub(crate) _priv: (),
}
17
18impl<T> SqlRequest<T> {
19 pub fn new(req: Request<T, SqlContext>) -> Self {
21 SqlRequest { req, _priv: () }
22 }
23}
24
#[cfg(feature = "sqlite")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
impl<'r, T: Decode<'r, sqlx::Sqlite> + Type<sqlx::Sqlite>>
    sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for SqlRequest<T>
{
    /// Decodes a SQLite row into a `SqlRequest`.
    ///
    /// `job`, `id`, `run_at` and `status` are required; the remaining
    /// bookkeeping columns fall back to their defaults when missing or
    /// unreadable. SQLite stores timestamps as unix seconds (`i64`).
    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
        use chrono::DateTime;
        use sqlx::Row;
        use std::str::FromStr;

        // Required columns: a decode failure here rejects the whole row.
        let job: T = row.try_get("job")?;
        let id_str: &str = row.try_get("id")?;
        let task_id = TaskId::from_str(id_str).map_err(|e| sqlx::Error::ColumnDecode {
            index: "id".to_string(),
            source: Box::new(e),
        })?;

        let mut parts = Parts::<SqlContext>::default();
        parts.task_id = task_id;
        // An absent/unreadable attempts column means "never attempted".
        let attempts: i32 = row.try_get("attempts").unwrap_or(0);
        parts.attempt = Attempt::new_with_value(attempts as usize);

        let mut ctx = crate::context::SqlContext::new();

        // run_at is stored as unix seconds; an out-of-range value falls
        // back to the epoch default rather than failing the row.
        let run_at_secs: i64 = row.try_get("run_at")?;
        ctx.set_run_at(DateTime::from_timestamp(run_at_secs, 0).unwrap_or_default());

        if let Ok(max_attempts) = row.try_get("max_attempts") {
            ctx.set_max_attempts(max_attempts)
        }

        // Optional bookkeeping columns, defaulted when unavailable.
        ctx.set_done_at(row.try_get::<Option<i64>, _>("done_at").unwrap_or_default());
        ctx.set_lock_at(row.try_get::<Option<i64>, _>("lock_at").unwrap_or_default());
        ctx.set_last_error(row.try_get("last_error").unwrap_or_default());

        let status: String = row.try_get("status")?;
        ctx.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
            index: "status".to_string(),
            source: Box::new(e),
        })?);

        let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
        let worker = lock_by
            .as_deref()
            .map(WorkerId::from_str)
            .transpose()
            .map_err(|_| sqlx::Error::ColumnDecode {
                index: "lock_by".to_string(),
                source: "Could not parse lock_by as a WorkerId".into(),
            })?;
        ctx.set_lock_by(worker);

        parts.context = ctx;
        Ok(SqlRequest {
            req: Request::new_with_parts(job, parts),
            _priv: (),
        })
    }
}
89
#[cfg(feature = "postgres")]
#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
impl<'r, T: Decode<'r, sqlx::Postgres> + Type<sqlx::Postgres>>
    sqlx::FromRow<'r, sqlx::postgres::PgRow> for SqlRequest<T>
{
    /// Decodes a Postgres row into a `SqlRequest`.
    ///
    /// `job`, `id`, `run_at` and `status` are required; the remaining
    /// bookkeeping columns fall back to their defaults when missing or
    /// unreadable. Postgres returns native timestamps, so the optional
    /// `done_at`/`lock_at` values are converted to unix seconds for
    /// `SqlContext`.
    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
        use chrono::Utc;
        use sqlx::Row;
        use std::str::FromStr;

        // Required columns: a decode failure here rejects the whole row.
        let job: T = row.try_get("job")?;
        let task_id: TaskId =
            TaskId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
                index: "id".to_string(),
                source: Box::new(e),
            })?;
        let mut parts = Parts::<SqlContext>::default();
        parts.task_id = task_id;

        // An absent/unreadable attempts column means "never attempted".
        let attempt: i32 = row.try_get("attempts").unwrap_or(0);
        parts.attempt = Attempt::new_with_value(attempt as usize);
        let mut context = SqlContext::new();

        let run_at = row.try_get("run_at")?;
        context.set_run_at(run_at);

        if let Ok(max_attempts) = row.try_get("max_attempts") {
            context.set_max_attempts(max_attempts)
        }

        // Optional timestamps, converted to unix seconds.
        let done_at: Option<chrono::DateTime<Utc>> = row.try_get("done_at").unwrap_or_default();
        context.set_done_at(done_at.map(|d| d.timestamp()));

        let lock_at: Option<chrono::DateTime<Utc>> = row.try_get("lock_at").unwrap_or_default();
        context.set_lock_at(lock_at.map(|d| d.timestamp()));

        let last_error = row.try_get("last_error").unwrap_or_default();
        context.set_last_error(last_error);

        let status: String = row.try_get("status")?;
        // Fix: previously this error reported `index: "job"`, blaming the
        // wrong column when the status string failed to parse. It now
        // correctly reports "status", matching the SQLite implementation.
        context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
            index: "status".to_string(),
            source: Box::new(e),
        })?);

        let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
        context.set_lock_by(
            lock_by
                .as_deref()
                .map(WorkerId::from_str)
                .transpose()
                .map_err(|_| sqlx::Error::ColumnDecode {
                    index: "lock_by".to_string(),
                    source: "Could not parse lock_by as a WorkerId".into(),
                })?,
        );
        parts.context = context;
        Ok(SqlRequest {
            req: Request::new_with_parts(job, parts),
            _priv: (),
        })
    }
}
153
#[cfg(feature = "mysql")]
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
impl<'r, T: Decode<'r, sqlx::MySql> + Type<sqlx::MySql>> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>
    for SqlRequest<T>
{
    /// Decodes a MySQL row into a `SqlRequest`.
    ///
    /// `job`, `id`, `run_at` and `status` are required; the remaining
    /// bookkeeping columns fall back to their defaults when missing or
    /// unreadable. MySQL returns naive datetimes, which are interpreted
    /// as UTC and converted to unix seconds for `SqlContext`.
    fn from_row(row: &'r sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
        use sqlx::Row;
        use std::str::FromStr;

        // Required columns: a decode failure here rejects the whole row.
        let job: T = row.try_get("job")?;
        let task_id: TaskId =
            TaskId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
                index: "id".to_string(),
                source: Box::new(e),
            })?;
        let mut parts = Parts::<SqlContext>::default();
        parts.task_id = task_id;

        // An absent/unreadable attempts column means "never attempted".
        let attempt: i32 = row.try_get("attempts").unwrap_or(0);
        parts.attempt = Attempt::new_with_value(attempt as usize);

        let mut context = SqlContext::new();

        let run_at = row.try_get("run_at")?;
        context.set_run_at(run_at);

        if let Ok(max_attempts) = row.try_get("max_attempts") {
            context.set_max_attempts(max_attempts)
        }

        // Optional naive datetimes, treated as UTC and converted to
        // unix seconds.
        let done_at: Option<chrono::NaiveDateTime> = row.try_get("done_at").unwrap_or_default();
        context.set_done_at(done_at.map(|d| d.and_utc().timestamp()));

        let lock_at: Option<chrono::NaiveDateTime> = row.try_get("lock_at").unwrap_or_default();
        context.set_lock_at(lock_at.map(|d| d.and_utc().timestamp()));

        let last_error = row.try_get("last_error").unwrap_or_default();
        context.set_last_error(last_error);

        let status: String = row.try_get("status")?;
        // Fix: previously this error reported `index: "job"`, blaming the
        // wrong column when the status string failed to parse. It now
        // correctly reports "status", matching the SQLite implementation.
        context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
            index: "status".to_string(),
            source: Box::new(e),
        })?);

        let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
        context.set_lock_by(
            lock_by
                .as_deref()
                .map(WorkerId::from_str)
                .transpose()
                .map_err(|_| sqlx::Error::ColumnDecode {
                    index: "lock_by".to_string(),
                    source: "Could not parse lock_by as a WorkerId".into(),
                })?,
        );
        parts.context = context;
        Ok(SqlRequest {
            req: Request::new_with_parts(job, parts),
            _priv: (),
        })
    }
}