Skip to main content

fast_cache/commands/
getex.rs

1//! GETEX command parsing and execution.
2
3use crate::commands::EngineCommandDispatch;
4#[cfg(feature = "server")]
5use crate::protocol::FastCodec;
6use crate::protocol::{FastCommand, FastRequest, FastResponse, Frame};
7#[cfg(feature = "server")]
8use crate::server::commands::{
9    BorrowedCommandContext, DirectCommandContext, DirectFastCommand, FastCommandContext,
10    FastDirectCommand, FcnpCommandContext, FcnpDirectCommand, FcnpDispatch, RawCommandContext,
11    RawDirectCommand,
12};
13#[cfg(feature = "server")]
14use crate::server::wire::ServerWire;
15use crate::storage::{
16    Command, EngineCommandContext, EngineFastFuture, EngineFrameFuture, ExpirationChange, ShardKey,
17    ShardOperation, ShardReply, hash_key, now_millis,
18};
19use crate::{FastCacheError, Result};
20
21use super::DecodedFastCommand;
22use super::parsing::{CommandArity, TtlMillis};
23
/// Zero-sized marker for the `GETEX` command; all of its behaviour lives in the trait impls
/// attached to this type.
pub(crate) struct GetEx;

/// The single `GetEx` value exposed to the rest of the crate.
pub(crate) static COMMAND: GetEx = GetEx;
26
/// Expiration option carried by a parsed `GETEX` request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GetExExpiration {
    /// No option supplied: the key's current expiration is left untouched.
    Keep,
    /// `EX` / `PX`: a relative TTL, already normalised to milliseconds.
    ExpireIn(u64),
    /// `PERSIST`: clear the key's expiration.
    Persist,
}
33
34impl GetExExpiration {
35    fn to_engine(self) -> ExpirationChange {
36        match self {
37            Self::Keep => ExpirationChange::Keep,
38            Self::ExpireIn(ttl_ms) => ExpirationChange::ExpireAt(relative_expire_at_ms(ttl_ms)),
39            Self::Persist => ExpirationChange::Persist,
40        }
41    }
42
43    #[cfg(feature = "server")]
44    fn expire_at_ms(self, now_ms: u64) -> Option<Option<u64>> {
45        match self {
46            Self::Keep => None,
47            Self::ExpireIn(ttl_ms) => Some(Some(now_ms.saturating_add(ttl_ms))),
48            Self::Persist => Some(None),
49        }
50    }
51}
52
53#[derive(Debug, Clone)]
54pub(crate) struct OwnedGetEx {
55    key: Vec<u8>,
56    expiration: GetExExpiration,
57}
58
59impl OwnedGetEx {
60    fn new(key: Vec<u8>, expiration: GetExExpiration) -> Self {
61        Self { key, expiration }
62    }
63}
64
65impl super::OwnedCommandData for OwnedGetEx {
66    type Spec = GetEx;
67
68    fn route_key(&self) -> Option<&[u8]> {
69        Some(&self.key)
70    }
71
72    fn to_borrowed_command(&self) -> super::BorrowedCommandBox<'_> {
73        Box::new(BorrowedGetEx::new(&self.key, self.expiration))
74    }
75}
76
77#[derive(Debug, Clone, Copy)]
78pub(crate) struct BorrowedGetEx<'a> {
79    key: &'a [u8],
80    expiration: GetExExpiration,
81}
82
83impl<'a> BorrowedGetEx<'a> {
84    fn new(key: &'a [u8], expiration: GetExExpiration) -> Self {
85        Self { key, expiration }
86    }
87}
88
89impl<'a> super::BorrowedCommandData<'a> for BorrowedGetEx<'a> {
90    type Spec = GetEx;
91
92    fn route_key(&self) -> Option<&'a [u8]> {
93        Some(self.key)
94    }
95
96    fn to_owned_command(&self) -> Command {
97        Command::new(Box::new(OwnedGetEx::new(
98            self.key.to_vec(),
99            self.expiration,
100        )))
101    }
102
103    fn execute_engine<'b>(&'b self, ctx: EngineCommandContext<'b>) -> EngineFrameFuture<'b>
104    where
105        'a: 'b,
106    {
107        let key = self.key;
108        let expiration = self.expiration.to_engine();
109        Box::pin(async move { GetEx::execute_engine_frame(ctx, key, expiration).await })
110    }
111
112    #[cfg(feature = "server")]
113    fn execute_borrowed_frame(&self, store: &crate::storage::EmbeddedStore, now_ms: u64) -> Frame {
114        match Self::execute_embedded(store, self.key, self.expiration, now_ms) {
115            Some(value) => Frame::BlobString(value),
116            None => Frame::Null,
117        }
118    }
119
120    #[cfg(feature = "server")]
121    fn execute_borrowed(&self, ctx: BorrowedCommandContext<'_, '_, '_>) {
122        match Self::execute_embedded(ctx.store, self.key, self.expiration, now_millis()) {
123            Some(value) => ServerWire::write_resp_blob_string(ctx.out, &value),
124            None => ctx.out.extend_from_slice(b"$-1\r\n"),
125        }
126    }
127
128    #[cfg(feature = "server")]
129    fn execute_direct_borrowed(&self, ctx: DirectCommandContext) -> Frame {
130        let expire_at_ms = self.expiration.expire_at_ms(ctx.now_ms);
131        let value = match expire_at_ms {
132            Some(expire_at_ms) => ctx.getex(self.key, expire_at_ms),
133            None => ctx.get(self.key),
134        };
135        value.map_or(Frame::Null, Frame::BlobString)
136    }
137}
138
139impl super::CommandSpec for GetEx {
140    const NAME: &'static str = "GETEX";
141    const MUTATES_VALUE: bool = true;
142}
143
144impl super::OwnedCommandParse for GetEx {
145    fn parse_owned(parts: &[Vec<u8>]) -> Result<Command> {
146        CommandArity::<Self>::at_least(parts.len(), 2, "key")?;
147        Ok(Command::new(Box::new(OwnedGetEx::new(
148            parts[1].clone(),
149            parse_getex_expiration(&parts[2..])?,
150        ))))
151    }
152}
153
154impl<'a> super::BorrowedCommandParse<'a> for GetEx {
155    fn parse_borrowed(parts: &[&'a [u8]]) -> Result<super::BorrowedCommandBox<'a>> {
156        CommandArity::<Self>::at_least(parts.len(), 2, "key")?;
157        Ok(Box::new(BorrowedGetEx::new(
158            parts[1],
159            parse_getex_expiration(&parts[2..])?,
160        )))
161    }
162}
163
164impl DecodedFastCommand for GetEx {
165    fn matches_decoded_fast(&self, command: &FastCommand<'_>) -> bool {
166        matches!(command, FastCommand::GetEx { .. })
167    }
168}
169
170impl EngineCommandDispatch for GetEx {
171    fn execute_engine_fast<'a>(
172        &'static self,
173        ctx: EngineCommandContext<'a>,
174        request: FastRequest<'a>,
175    ) -> EngineFastFuture<'a> {
176        Box::pin(async move {
177            match request.command {
178                FastCommand::GetEx { key, ttl_ms } => {
179                    GetEx::execute_engine_response(
180                        ctx,
181                        key,
182                        ExpirationChange::ExpireAt(relative_expire_at_ms(ttl_ms)),
183                    )
184                    .await
185                }
186                _ => Ok(FastResponse::Error(b"ERR unsupported command".to_vec())),
187            }
188        })
189    }
190}
191
192impl GetEx {
193    async fn execute_engine_frame(
194        ctx: EngineCommandContext<'_>,
195        key: &[u8],
196        expiration: ExpirationChange,
197    ) -> Result<Frame> {
198        match Self::load_and_update(ctx, key, expiration).await? {
199            Some(value) => Ok(Frame::BlobString(value)),
200            None => Ok(Frame::Null),
201        }
202    }
203
204    async fn execute_engine_response(
205        ctx: EngineCommandContext<'_>,
206        key: &[u8],
207        expiration: ExpirationChange,
208    ) -> Result<FastResponse> {
209        match Self::load_and_update(ctx, key, expiration).await? {
210            Some(value) => Ok(FastResponse::Value(value)),
211            None => Ok(FastResponse::Null),
212        }
213    }
214
215    async fn load_and_update(
216        ctx: EngineCommandContext<'_>,
217        key: &[u8],
218        expiration: ExpirationChange,
219    ) -> Result<Option<Vec<u8>>> {
220        let key_hash = hash_key(key);
221        let shard = ctx.route_key_hash(key_hash);
222        match ctx
223            .request(
224                shard,
225                ShardOperation::GetEx {
226                    key_hash,
227                    key: ShardKey::inline(key),
228                    expiration,
229                },
230            )
231            .await?
232        {
233            ShardReply::Value(value) => Ok(value),
234            _ => Err(FastCacheError::Command(
235                "GETEX received unexpected shard reply".into(),
236            )),
237        }
238    }
239}
240
#[cfg(feature = "server")]
impl<'a> BorrowedGetEx<'a> {
    /// Reads `key` from the embedded store and applies `expiration` only when the key exists.
    ///
    /// Returns the value obtained by the read, or `None` on a miss; on a miss the store is not
    /// touched further.
    ///
    /// NOTE(review): the read and the expiration update are two separate store calls. Confirm
    /// whether `EmbeddedStore` guarantees they cannot interleave with a concurrent write.
    fn execute_embedded(
        store: &crate::storage::EmbeddedStore,
        key: &[u8],
        expiration: GetExExpiration,
        now_ms: u64,
    ) -> Option<Vec<u8>> {
        let value = store.get(key)?;
        match expiration.expire_at_ms(now_ms) {
            None => {}
            Some(None) => {
                store.persist(key);
            }
            Some(Some(expire_at_ms)) => {
                store.expire(key, expire_at_ms);
            }
        }
        Some(value)
    }
}
264
265fn parse_getex_expiration(args: &[impl AsRef<[u8]>]) -> Result<GetExExpiration> {
266    match args {
267        [] => Ok(GetExExpiration::Keep),
268        [option] if option.as_ref().eq_ignore_ascii_case(b"PERSIST") => {
269            Ok(GetExExpiration::Persist)
270        }
271        [option, value] if option.as_ref().eq_ignore_ascii_case(b"EX") => {
272            TtlMillis::<GetEx>::seconds(value.as_ref()).map(GetExExpiration::ExpireIn)
273        }
274        [option, value] if option.as_ref().eq_ignore_ascii_case(b"PX") => {
275            TtlMillis::<GetEx>::millis(value.as_ref()).map(GetExExpiration::ExpireIn)
276        }
277        _ => Err(FastCacheError::Command("GETEX syntax error".into())),
278    }
279}
280
281fn relative_expire_at_ms(ttl_ms: u64) -> u64 {
282    now_millis().saturating_add(ttl_ms)
283}
284
#[cfg(feature = "server")]
impl RawDirectCommand for GetEx {
    /// Executes `GETEX` from raw RESP arguments against the embedded store.
    ///
    /// Writes one of three replies:
    /// * a blob string on a hit,
    /// * the RESP null `$-1` on a miss,
    /// * `ERR syntax error` when the arguments do not parse.
    fn execute(&self, ctx: RawCommandContext<'_, '_, '_>) {
        let Some((key, expiration)) = parse_getex_raw(ctx.args.as_slice()) else {
            ServerWire::write_resp_error(ctx.out, "ERR syntax error");
            return;
        };
        let now_ms = now_millis();
        let value = BorrowedGetEx::execute_embedded(ctx.store, key, expiration, now_ms);
        if let Some(bytes) = value {
            ServerWire::write_resp_blob_string(ctx.out, &bytes);
        } else {
            ctx.out.extend_from_slice(b"$-1\r\n");
        }
    }
}
299
/// Parses raw `GETEX` arguments (key first, command name already removed) into the key and the
/// requested expiration option.
///
/// Option names (`PERSIST`, `EX`, `PX`) are matched case-insensitively. Returns `None` for any
/// unrecognised argument shape, or when `TtlMillis` rejects the TTL value.
///
/// The returned key borrows from the argument bytes (`'a`), not from the slice that holds
/// them, so the outer slice only needs to live for the duration of this call.
#[cfg(feature = "server")]
fn parse_getex_raw<'a>(args: &[&'a [u8]]) -> Option<(&'a [u8], GetExExpiration)> {
    match args {
        [key] => Some((*key, GetExExpiration::Keep)),
        [key, option] if option.eq_ignore_ascii_case(b"PERSIST") => {
            Some((*key, GetExExpiration::Persist))
        }
        [key, option, value] if option.eq_ignore_ascii_case(b"EX") => {
            TtlMillis::<()>::ascii_seconds(value)
                .map(|ttl_ms| (*key, GetExExpiration::ExpireIn(ttl_ms)))
        }
        [key, option, value] if option.eq_ignore_ascii_case(b"PX") => {
            TtlMillis::<()>::ascii_millis(value)
                .map(|ttl_ms| (*key, GetExExpiration::ExpireIn(ttl_ms)))
        }
        _ => None,
    }
}
318
#[cfg(feature = "server")]
impl DirectFastCommand for GetEx {
    /// Executes a fast-protocol `GETEX` through the direct context.
    ///
    /// The deadline is `ctx.now_ms + ttl_ms`, saturating. Any non-`GetEx` command yields an
    /// error response.
    fn execute_direct_fast(
        &self,
        ctx: DirectCommandContext,
        request: FastRequest<'_>,
    ) -> FastResponse {
        let FastCommand::GetEx { key, ttl_ms } = request.command else {
            return FastResponse::Error(b"ERR unsupported command".to_vec());
        };
        let expire_at_ms = ctx.now_ms.saturating_add(ttl_ms);
        match ctx.getex(key, Some(expire_at_ms)) {
            Some(value) => FastResponse::Value(value),
            None => FastResponse::Null,
        }
    }
}
334
#[cfg(feature = "server")]
impl FastDirectCommand for GetEx {
    /// Executes a fast-protocol `GETEX` against `ctx.store`.
    ///
    /// On a hit, the key's deadline is set to `now_millis() + ttl_ms` and the value is written.
    /// On a miss, a fast-protocol null is written. Any non-`GetEx` command writes an error.
    fn execute_fast(&self, ctx: FastCommandContext<'_, '_>, command: FastCommand<'_>) {
        let FastCommand::GetEx { key, ttl_ms } = command else {
            ServerWire::write_fast_error(ctx.out, "ERR unsupported command");
            return;
        };
        if let Some(value) = ctx.store.get(key) {
            ctx.store.expire(key, relative_expire_at_ms(ttl_ms));
            ServerWire::write_fast_value(ctx.out, &value);
        } else {
            ServerWire::write_fast_null(ctx.out);
        }
    }
}
350
#[cfg(feature = "server")]
impl FcnpDirectCommand for GetEx {
    /// FCNP opcode this handler is registered under.
    fn opcode(&self) -> u8 {
        4
    }

    /// Decodes one FCNP frame and, if it is a `GetEx` request, executes it against `ctx.store`.
    ///
    /// Returns `Unsupported` when the frame fails to decode, is incomplete, is not `GetEx`, or
    /// carries no key hash. When `ctx.owned_shard_id` is set, the request must name that shard
    /// as its route shard, and `ctx` must confirm the shard owns the key; otherwise
    /// `ERR FCNP route shard mismatch` is written. Every handled frame reports the consumed
    /// byte count.
    fn try_execute_fcnp(&self, ctx: FcnpCommandContext<'_, '_, '_, '_>) -> FcnpDispatch {
        let frame_bytes = &ctx.frame.buf[..ctx.frame.frame_len];
        let Ok(Some((request, consumed))) = FastCodec::decode_request(frame_bytes) else {
            return FcnpDispatch::Unsupported;
        };
        let FastCommand::GetEx { key, ttl_ms } = request.command else {
            return FcnpDispatch::Unsupported;
        };
        let Some(key_hash) = request.key_hash else {
            return FcnpDispatch::Unsupported;
        };

        if let Some(owned_shard_id) = ctx.owned_shard_id {
            // A missing route shard counts as a mismatch; the ownership check only runs when the
            // route shard equals this context's shard.
            let route_ok = match request.route_shard {
                Some(shard) => {
                    let route_shard = shard as usize;
                    route_shard == owned_shard_id
                        && ctx.request_matches_owned_shard_for_key(route_shard, key_hash, key)
                }
                None => false,
            };
            if !route_ok {
                ServerWire::write_fast_error(ctx.out, "ERR FCNP route shard mismatch");
                return FcnpDispatch::Complete(consumed);
            }
        }

        if let Some(value) = ctx.store.get(key) {
            ctx.store.expire(key, relative_expire_at_ms(ttl_ms));
            ServerWire::write_fast_value(ctx.out, &value);
        } else {
            ServerWire::write_fast_null(ctx.out);
        }
        FcnpDispatch::Complete(consumed)
    }
}