// scheduler/valkey_store.rs

1use crate::error::StoreErrorKind;
2use crate::model::JobState;
3use crate::store::{ResilientStateStore, ResilientStoreError, StateStore};
4use redis::{AsyncCommands, Client, ErrorKind, ServerErrorKind, aio::ConnectionManager};
5use std::error::Error;
6use std::fmt::{self, Display, Formatter};
7
/// Key prefix used by [`ValkeyStateStore::new`] and [`ValkeyStateStore::resilient`].
const DEFAULT_KEY_PREFIX: &str = "scheduler:valkey:job-state:";
/// Previous default prefix. Stores configured with [`DEFAULT_KEY_PREFIX`] still read
/// (as a fallback) and delete keys under this prefix so older state is not orphaned.
const LEGACY_DEFAULT_KEY_PREFIX: &str = "scheduler:job-state:";
10
11#[derive(Debug, Clone)]
12pub struct ValkeyStateStore {
13    connection: ConnectionManager,
14    key_prefix: String,
15}
16
17impl ValkeyStateStore {
18    pub async fn new(url: impl AsRef<str>) -> Result<Self, redis::RedisError> {
19        Self::with_prefix(url, DEFAULT_KEY_PREFIX).await
20    }
21
22    /// Creates a Valkey-backed store that permanently falls back to an
23    /// in-process mirror after connection-class failures.
24    pub async fn resilient(
25        url: impl AsRef<str>,
26    ) -> Result<ResilientStateStore<Self>, ValkeyStoreError> {
27        Self::with_prefix_resilient(url, DEFAULT_KEY_PREFIX).await
28    }
29
30    pub async fn with_prefix(
31        url: impl AsRef<str>,
32        key_prefix: impl Into<String>,
33    ) -> Result<Self, redis::RedisError> {
34        let client = Client::open(url.as_ref())?;
35        Self::from_client(client, key_prefix).await
36    }
37
38    pub async fn from_client(
39        client: Client,
40        key_prefix: impl Into<String>,
41    ) -> Result<Self, redis::RedisError> {
42        let connection = client.get_connection_manager().await?;
43        Ok(Self {
44            connection,
45            key_prefix: key_prefix.into(),
46        })
47    }
48
49    pub async fn with_prefix_resilient(
50        url: impl AsRef<str>,
51        key_prefix: impl Into<String>,
52    ) -> Result<ResilientStateStore<Self>, ValkeyStoreError> {
53        ResilientStateStore::from_result(
54            Self::with_prefix(url, key_prefix)
55                .await
56                .map_err(ValkeyStoreError::from),
57        )
58    }
59
60    pub async fn from_client_resilient(
61        client: Client,
62        key_prefix: impl Into<String>,
63    ) -> Result<ResilientStateStore<Self>, ValkeyStoreError> {
64        ResilientStateStore::from_result(
65            Self::from_client(client, key_prefix)
66                .await
67                .map_err(ValkeyStoreError::from),
68        )
69    }
70
71    fn state_key(&self, job_id: &str) -> String {
72        state_key(&self.key_prefix, job_id)
73    }
74
75    fn legacy_state_key(&self, job_id: &str) -> Option<String> {
76        if self.key_prefix == DEFAULT_KEY_PREFIX {
77            Some(state_key(LEGACY_DEFAULT_KEY_PREFIX, job_id))
78        } else {
79            None
80        }
81    }
82}
83
/// Builds a storage key by appending `job_id` directly to `prefix` (no separator
/// is inserted; the prefix is expected to carry its own trailing delimiter).
fn state_key(prefix: &str, job_id: &str) -> String {
    [prefix, job_id].concat()
}
87
88#[derive(Debug)]
89pub enum ValkeyStoreError {
90    Redis(redis::RedisError),
91    Codec(serde_json::Error),
92}
93
94impl Display for ValkeyStoreError {
95    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
96        match self {
97            ValkeyStoreError::Redis(error) => write!(f, "{error}"),
98            ValkeyStoreError::Codec(error) => write!(f, "{error}"),
99        }
100    }
101}
102
103impl Error for ValkeyStoreError {
104    fn source(&self) -> Option<&(dyn Error + 'static)> {
105        match self {
106            ValkeyStoreError::Redis(error) => Some(error),
107            ValkeyStoreError::Codec(error) => Some(error),
108        }
109    }
110}
111
112impl From<redis::RedisError> for ValkeyStoreError {
113    fn from(error: redis::RedisError) -> Self {
114        Self::Redis(error)
115    }
116}
117
118impl From<serde_json::Error> for ValkeyStoreError {
119    fn from(error: serde_json::Error) -> Self {
120        Self::Codec(error)
121    }
122}
123
124impl ValkeyStoreError {
125    pub fn is_connection_issue(&self) -> bool {
126        match self {
127            Self::Redis(error) => {
128                error.is_connection_dropped()
129                    || error.is_connection_refusal()
130                    || error.is_timeout()
131                    || matches!(
132                        error.kind(),
133                        ErrorKind::Io
134                            | ErrorKind::ClusterConnectionNotFound
135                            | ErrorKind::Server(ServerErrorKind::BusyLoading)
136                            | ErrorKind::Server(ServerErrorKind::ClusterDown)
137                            | ErrorKind::Server(ServerErrorKind::MasterDown)
138                            | ErrorKind::Server(ServerErrorKind::TryAgain)
139                    )
140            }
141            Self::Codec(_) => false,
142        }
143    }
144}
145
146impl ResilientStoreError for ValkeyStoreError {
147    fn is_connection_issue(&self) -> bool {
148        self.is_connection_issue()
149    }
150}
151
152impl StateStore for ValkeyStateStore {
153    type Error = ValkeyStoreError;
154
155    async fn load(&self, job_id: &str) -> Result<Option<JobState>, Self::Error> {
156        let mut connection = self.connection.clone();
157        let payload: Option<String> = connection
158            .get(self.state_key(job_id))
159            .await
160            .map_err(ValkeyStoreError::from)?;
161
162        let payload = match payload {
163            Some(payload) => Some(payload),
164            None => {
165                if let Some(legacy_key) = self.legacy_state_key(job_id) {
166                    connection
167                        .get(legacy_key)
168                        .await
169                        .map_err(ValkeyStoreError::from)?
170                } else {
171                    None
172                }
173            }
174        };
175
176        payload
177            .map(|value| serde_json::from_str(&value).map_err(ValkeyStoreError::from))
178            .transpose()
179    }
180
181    async fn save(&self, state: &JobState) -> Result<(), Self::Error> {
182        let mut connection = self.connection.clone();
183        let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
184        connection
185            .set(self.state_key(&state.job_id), payload)
186            .await
187            .map_err(ValkeyStoreError::from)
188    }
189
190    async fn delete(&self, job_id: &str) -> Result<(), Self::Error> {
191        let mut connection = self.connection.clone();
192        let _: usize = connection
193            .del(self.state_key(job_id))
194            .await
195            .map_err(ValkeyStoreError::from)?;
196
197        if let Some(legacy_key) = self.legacy_state_key(job_id) {
198            let _: usize = connection
199                .del(legacy_key)
200                .await
201                .map_err(ValkeyStoreError::from)?;
202        }
203
204        Ok(())
205    }
206
207    fn classify_error(error: &Self::Error) -> StoreErrorKind
208    where
209        Self: Sized,
210    {
211        if matches!(error, ValkeyStoreError::Codec(_)) {
212            StoreErrorKind::Data
213        } else if error.is_connection_issue() {
214            StoreErrorKind::Connection
215        } else {
216            StoreErrorKind::Unknown
217        }
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::{
        DEFAULT_KEY_PREFIX, LEGACY_DEFAULT_KEY_PREFIX, ValkeyStateStore, ValkeyStoreError,
        state_key,
    };
    use crate::error::StoreErrorKind;
    use crate::model::JobState;
    use crate::store::StateStore;
    use chrono::{TimeDelta, Utc};

    /// Builds a `ValkeyStoreError::Redis` wrapping an I/O error of the given kind.
    fn io_error(kind: std::io::ErrorKind) -> ValkeyStoreError {
        ValkeyStoreError::from(redis::RedisError::from(std::io::Error::from(kind)))
    }

    /// Builds a `ValkeyStoreError::Codec` from truncated JSON.
    fn codec_error() -> ValkeyStoreError {
        ValkeyStoreError::from(serde_json::from_str::<JobState>("{").unwrap_err())
    }

    #[test]
    fn state_key_uses_custom_prefix() {
        assert_eq!(state_key("custom:", "job-1"), "custom:job-1");
        assert_eq!(
            state_key(DEFAULT_KEY_PREFIX, "job-2"),
            "scheduler:valkey:job-state:job-2"
        );
    }

    #[test]
    fn legacy_default_prefix_is_stable() {
        assert_eq!(
            state_key(LEGACY_DEFAULT_KEY_PREFIX, "job-3"),
            "scheduler:job-state:job-3"
        );
    }

    #[test]
    fn job_state_json_round_trip() {
        let state = JobState {
            job_id: "job-1".to_string(),
            trigger_count: 2,
            last_run_at: Some(Utc::now()),
            last_success_at: Some(Utc::now() + TimeDelta::seconds(1)),
            next_run_at: Some(Utc::now() + TimeDelta::seconds(5)),
            last_error: Some("boom".to_string()),
        };

        let encoded = serde_json::to_string(&state).unwrap();
        let decoded: JobState = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded, state);
    }

    #[test]
    fn io_errors_are_classified_as_connection_issues() {
        assert!(io_error(std::io::ErrorKind::BrokenPipe).is_connection_issue());
    }

    #[test]
    fn timeout_errors_are_classified_as_connection_issues() {
        assert!(io_error(std::io::ErrorKind::TimedOut).is_connection_issue());
    }

    #[test]
    fn refused_connections_are_classified_as_connection_issues() {
        assert!(io_error(std::io::ErrorKind::ConnectionRefused).is_connection_issue());
    }

    #[test]
    fn codec_errors_are_not_classified_as_connection_issues() {
        assert!(!codec_error().is_connection_issue());
    }

    #[test]
    fn classify_error_maps_codec_errors_to_data() {
        let error = codec_error();
        assert!(matches!(
            <ValkeyStateStore as StateStore>::classify_error(&error),
            StoreErrorKind::Data
        ));
    }

    #[test]
    fn classify_error_maps_io_errors_to_connection() {
        let error = io_error(std::io::ErrorKind::BrokenPipe);
        assert!(matches!(
            <ValkeyStateStore as StateStore>::classify_error(&error),
            StoreErrorKind::Connection
        ));
    }
}
285}