use apalis_core::request::Parts;
use apalis_core::task::attempt::Attempt;
use apalis_core::task::task_id::TaskId;
use apalis_core::{request::Request, worker::WorkerId};

use serde::{Deserialize, Serialize};
use sqlx::{Decode, Type};

use crate::context::SqlContext;

/// Wrapper for a [`Request`] as it is stored by the SQL backends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlRequest<T> {
    /// The inner request, carrying the job payload and its [`SqlContext`].
    pub req: Request<T, SqlContext>,
    pub(crate) _priv: (),
}

impl<T> SqlRequest<T> {
    /// Creates a new `SqlRequest` from an existing request.
    pub fn new(req: Request<T, SqlContext>) -> Self {
        SqlRequest { req, _priv: () }
    }
}

#[cfg(feature = "sqlite")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
impl<'r, T: Decode<'r, sqlx::Sqlite> + Type<sqlx::Sqlite>>
    sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for SqlRequest<T>
{
    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
        use chrono::DateTime;
        use sqlx::Row;
        use std::str::FromStr;

        let job: T = row.try_get("job")?;
        let task_id: TaskId =
            TaskId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
                index: "id".to_string(),
                source: Box::new(e),
            })?;
        let mut parts = Parts::<SqlContext>::default();
        parts.task_id = task_id;

        let attempt: i32 = row.try_get("attempts").unwrap_or(0);
        parts.attempt = Attempt::new_with_value(attempt as usize);

        let mut context = SqlContext::new();

        let run_at: i64 = row.try_get("run_at")?;
        context.set_run_at(DateTime::from_timestamp(run_at, 0).unwrap_or_default());

        if let Ok(max_attempts) = row.try_get("max_attempts") {
            context.set_max_attempts(max_attempts)
        }

        let done_at: Option<i64> = row.try_get("done_at").unwrap_or_default();
        context.set_done_at(done_at);

        let lock_at: Option<i64> = row.try_get("lock_at").unwrap_or_default();
        context.set_lock_at(lock_at);

        let last_error = row.try_get("last_error").unwrap_or_default();
        context.set_last_error(last_error);

        let status: String = row.try_get("status")?;
        context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
            index: "status".to_string(),
            source: Box::new(e),
        })?);

        let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
        context.set_lock_by(
            lock_by
                .as_deref()
                .map(WorkerId::from_str)
                .transpose()
                .map_err(|_| sqlx::Error::ColumnDecode {
                    index: "lock_by".to_string(),
                    source: "Could not parse lock_by as a WorkerId".into(),
                })?,
        );

        let priority: i32 = row.try_get("priority").unwrap_or_default();
        context.set_priority(priority);

        parts.context = context;
        Ok(SqlRequest {
            req: Request::new_with_parts(job, parts),
            _priv: (),
        })
    }
}

#[cfg(feature = "postgres")]
#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
impl<'r, T: Decode<'r, sqlx::Postgres> + Type<sqlx::Postgres>>
    sqlx::FromRow<'r, sqlx::postgres::PgRow> for SqlRequest<T>
{
    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
        use chrono::Utc;
        use sqlx::Row;
        use std::str::FromStr;

        let job: T = row.try_get("job")?;
        let task_id: TaskId =
            TaskId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
                index: "id".to_string(),
                source: Box::new(e),
            })?;
        let mut parts = Parts::<SqlContext>::default();
        parts.task_id = task_id;

        let attempt: i32 = row.try_get("attempts").unwrap_or(0);
        parts.attempt = Attempt::new_with_value(attempt as usize);

        let mut context = SqlContext::new();

        let run_at = row.try_get("run_at")?;
        context.set_run_at(run_at);

        if let Ok(max_attempts) = row.try_get("max_attempts") {
            context.set_max_attempts(max_attempts)
        }

        // Postgres timestamps decode as `DateTime<Utc>`; `done_at` and `lock_at`
        // are stored in the context as unix seconds.
        let done_at: Option<chrono::DateTime<Utc>> = row.try_get("done_at").unwrap_or_default();
        context.set_done_at(done_at.map(|d| d.timestamp()));

        let lock_at: Option<chrono::DateTime<Utc>> = row.try_get("lock_at").unwrap_or_default();
        context.set_lock_at(lock_at.map(|d| d.timestamp()));

        let last_error = row.try_get("last_error").unwrap_or_default();
        context.set_last_error(last_error);

        let status: String = row.try_get("status")?;
        context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
            index: "status".to_string(),
            source: Box::new(e),
        })?);

        let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
        context.set_lock_by(
            lock_by
                .as_deref()
                .map(WorkerId::from_str)
                .transpose()
                .map_err(|_| sqlx::Error::ColumnDecode {
                    index: "lock_by".to_string(),
                    source: "Could not parse lock_by as a WorkerId".into(),
                })?,
        );

        let priority: i32 = row.try_get("priority").unwrap_or_default();
        context.set_priority(priority);

        parts.context = context;
        Ok(SqlRequest {
            req: Request::new_with_parts(job, parts),
            _priv: (),
        })
    }
}

#[cfg(feature = "mysql")]
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
impl<'r, T: Decode<'r, sqlx::MySql> + Type<sqlx::MySql>> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>
    for SqlRequest<T>
{
    fn from_row(row: &'r sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
        use sqlx::Row;
        use std::str::FromStr;

        let job: T = row.try_get("job")?;
        let task_id: TaskId =
            TaskId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
                index: "id".to_string(),
                source: Box::new(e),
            })?;
        let mut parts = Parts::<SqlContext>::default();
        parts.task_id = task_id;

        let attempt: i32 = row.try_get("attempts").unwrap_or(0);
        parts.attempt = Attempt::new_with_value(attempt as usize);

        let mut context = SqlContext::new();

        let run_at = row.try_get("run_at")?;
        context.set_run_at(run_at);

        if let Ok(max_attempts) = row.try_get("max_attempts") {
            context.set_max_attempts(max_attempts)
        }

        // MySQL returns `done_at` and `lock_at` as naive datetimes; they are
        // interpreted as UTC and stored in the context as unix seconds.
        let done_at: Option<chrono::NaiveDateTime> = row.try_get("done_at").unwrap_or_default();
        context.set_done_at(done_at.map(|d| d.and_utc().timestamp()));

        let lock_at: Option<chrono::NaiveDateTime> = row.try_get("lock_at").unwrap_or_default();
        context.set_lock_at(lock_at.map(|d| d.and_utc().timestamp()));

        let last_error = row.try_get("last_error").unwrap_or_default();
        context.set_last_error(last_error);

        let status: String = row.try_get("status")?;
        context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
            index: "status".to_string(),
            source: Box::new(e),
        })?);

        let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
        context.set_lock_by(
            lock_by
                .as_deref()
                .map(WorkerId::from_str)
                .transpose()
                .map_err(|_| sqlx::Error::ColumnDecode {
                    index: "lock_by".to_string(),
                    source: "Could not parse lock_by as a WorkerId".into(),
                })?,
        );

        let priority: i32 = row.try_get("priority").unwrap_or_default();
        context.set_priority(priority);

        parts.context = context;
        Ok(SqlRequest {
            req: Request::new_with_parts(job, parts),
            _priv: (),
        })
    }
}
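
// A minimal sketch of how `SqlRequest` composes with `Request` outside of a
// database row. The string payload is purely illustrative, and the accessors
// used in the assertion (`req.parts`, `Attempt::current`) are assumptions
// based on the `apalis_core` API already used above, not part of this file's
// contract.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wraps_request_with_sql_context() {
        // Build a request with default parts, just as `from_row` does before
        // filling them in from the row.
        let parts = Parts::<SqlContext>::default();
        let req = Request::new_with_parts("example-job".to_owned(), parts);

        // `SqlRequest::new` only wraps the request; nothing else is added.
        let wrapped = SqlRequest::new(req);
        assert_eq!(wrapped.req.parts.attempt.current(), 0);
    }
}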