// resp_async/context.rs

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use bytes::{BufMut, Bytes, BytesMut};
use tokio::sync::mpsc;

use crate::resp::Value;
use crate::response::RespError;

13/// A parsed Redis command with normalized name and arguments.
14#[derive(Debug, Clone, Eq, PartialEq)]
15pub struct Command {
16    pub name: Bytes,
17    pub name_upper: Bytes,
18    pub args: Vec<Value>,
19}
20
21impl Command {
22    pub fn new(name: Bytes, args: Vec<Value>) -> Self {
23        let name_upper = normalize_command_name(&name);
24        Self {
25            name,
26            name_upper,
27            args,
28        }
29    }
30
31    pub fn from_value(value: Value) -> Result<Self, RespError> {
32        match value {
33            Value::Array(mut items) => {
34                if items.is_empty() {
35                    return Err(RespError::invalid_data("ERR empty command"));
36                }
37                let name_value = items.remove(0);
38                let name = match name_value {
39                    Value::Bulk(b) | Value::Simple(b) => b,
40                    other => {
41                        return Err(RespError::invalid_data(format!(
42                            "ERR invalid command name: {:?}",
43                            other
44                        )));
45                    }
46                };
47                Ok(Command::new(name, items))
48            }
49            other => Err(RespError::invalid_data(format!(
50                "ERR expected array, got {:?}",
51                other
52            ))),
53        }
54    }
55}
56
57fn normalize_command_name(name: &Bytes) -> Bytes {
58    let mut needs = false;
59    for &b in name.iter() {
60        if b.is_ascii_lowercase() {
61            needs = true;
62            break;
63        }
64    }
65    if !needs {
66        return name.clone();
67    }
68    let mut buf = BytesMut::with_capacity(name.len());
69    for &b in name.iter() {
70        buf.put_u8(b.to_ascii_uppercase());
71    }
72    buf.freeze()
73}
74
75/// Per-request context passed to handlers.
76#[derive(Debug, Clone)]
77pub struct RequestContext {
78    pub command: Command,
79    pub peer_addr: SocketAddr,
80    pub local_addr: SocketAddr,
81    pub client_id: u64,
82    pub extensions: Extensions,
83    pub push: PushHandle,
84    pub pubsub: PubSubHandle,
85}
86
/// Typed extensions map stored in the request context.
///
/// Values are keyed by their concrete type, so at most one value of each type
/// is stored. Entries live behind `Arc`, which makes cloning the map cheap:
/// a clone shares every stored value with the original.
#[derive(Debug, Default, Clone)]
pub struct Extensions {
    inner: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Stores `value`, replacing any value of the same type already present.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
        self.inner.insert(TypeId::of::<T>(), Arc::new(value));
    }

    /// Returns a shared reference to the stored value of type `T`, if any.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.inner
            .get(&TypeId::of::<T>())
            // Deref to the trait object explicitly before downcasting.
            .and_then(|value| value.as_ref().downcast_ref::<T>())
    }

    /// Returns a mutable reference when this entry is uniquely owned.
    ///
    /// Returns `None` when no value of type `T` is stored, or when the entry
    /// is still shared with a clone of this map (its `Arc` is not unique).
    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
        self.inner
            .get_mut(&TypeId::of::<T>())
            .and_then(|value| Arc::get_mut(value))
            .and_then(|value| value.downcast_mut::<T>())
    }

    /// Returns `true` when a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.inner.contains_key(&TypeId::of::<T>())
    }

    /// Removes the value of type `T` from this map, returning it if present.
    ///
    /// The value comes back as an `Arc` because clones of this map may still
    /// share it. Clones made before the removal keep their own entry.
    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<Arc<T>> {
        self.inner
            .remove(&TypeId::of::<T>())
            // The entry was keyed by `TypeId::of::<T>()`, so this downcast
            // cannot fail; `.ok()` just avoids an unwrap.
            .and_then(|value| value.downcast::<T>().ok())
    }
}

113/// Handle for sending out-of-band push responses.
114#[derive(Debug, Clone)]
115pub struct PushHandle {
116    tx: mpsc::Sender<Value>,
117    close_tx: mpsc::Sender<()>,
118}
119
120impl PushHandle {
121    pub(crate) fn new(tx: mpsc::Sender<Value>, close_tx: mpsc::Sender<()>) -> Self {
122        Self { tx, close_tx }
123    }
124
125    pub async fn send(&self, value: Value) -> Result<(), PushError> {
126        match self.tx.try_send(value) {
127            Ok(()) => Ok(()),
128            Err(mpsc::error::TrySendError::Full(_)) => {
129                let _ = self.close_tx.try_send(());
130                Err(PushError::Full)
131            }
132            Err(mpsc::error::TrySendError::Closed(_)) => Err(PushError::Closed),
133        }
134    }
135}
136
/// Errors returned when sending push messages through `PushHandle::send`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PushError {
    /// The push channel had no free capacity; a close signal was also
    /// requested on a best-effort basis.
    Full,
    /// The receiving half of the push channel has been closed or dropped.
    Closed,
}

impl fmt::Display for PushError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            PushError::Full => "push channel is full",
            PushError::Closed => "push channel is closed",
        };
        f.write_str(message)
    }
}

// Lets callers propagate `PushError` through `Box<dyn Error>` and friends.
impl std::error::Error for PushError {}

/// Handle for updating Pub/Sub subscription count.
///
/// The counter lives in a shared `Arc<AtomicUsize>`, so every clone of the
/// handle reads and updates the same value.
#[derive(Debug, Clone)]
pub struct PubSubHandle {
    count: Arc<AtomicUsize>,
}

impl PubSubHandle {
    /// Wraps an existing shared counter.
    pub(crate) fn new(count: Arc<AtomicUsize>) -> Self {
        Self { count }
    }

    /// Overwrites the subscription count with `count`.
    pub fn set(&self, count: usize) {
        self.count.store(count, Ordering::Release);
    }

    /// Adds one subscription and returns the updated count.
    pub fn increment(&self) -> usize {
        self.count.fetch_add(1, Ordering::AcqRel) + 1
    }

    /// Removes one subscription and returns the updated count.
    ///
    /// The count saturates at zero: decrementing a zero counter leaves it at
    /// zero and returns `0`.
    pub fn decrement(&self) -> usize {
        // `checked_sub` yields `None` at zero. That makes `fetch_update` give
        // up without storing anything and return the current value in `Err`.
        let prev = self
            .count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |value| {
                value.checked_sub(1)
            })
            .unwrap_or_else(|current| current);
        prev.saturating_sub(1)
    }

    /// Returns the current subscription count.
    pub fn count(&self) -> usize {
        self.count.load(Ordering::Acquire)
    }
}

178/// Extractor wrapper for command.
179#[derive(Debug, Clone)]
180pub struct Cmd(pub Command);
181
/// Extractor wrapper for shared application state.
///
/// The state lives behind an `Arc`, so cloning a `State` only bumps the
/// reference count and shares the same `T`. `T` itself need not be `Clone`.
#[derive(Debug)]
pub struct State<T>(pub Arc<T>);

// Written by hand: `#[derive(Clone)]` would add a `T: Clone` bound even though
// cloning the `Arc` never clones `T`.
impl<T> Clone for State<T> {
    fn clone(&self) -> Self {
        State(Arc::clone(&self.0))
    }
}

/// Extractor that yields the remote (peer) socket address of the client.
#[derive(Clone, Copy, Debug)]
pub struct PeerAddr(pub SocketAddr);

/// Extractor that yields the local socket address of the connection.
#[derive(Clone, Copy, Debug)]
pub struct LocalAddr(pub SocketAddr);

/// Extractor that yields the numeric identifier of the client.
#[derive(Clone, Copy, Debug)]
pub struct ClientId(pub u64);

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    #[test]
    fn command_from_value_normalizes() {
        let value = Value::Array(vec![
            Value::Bulk(Bytes::from_static(b"ping")),
            Value::Bulk(Bytes::from_static(b"hi")),
        ]);
        let cmd = Command::from_value(value).unwrap();
        assert_eq!(cmd.name_upper.as_ref(), b"PING");
        assert_eq!(cmd.args.len(), 1);
    }

    /// A simple-string name is accepted, the raw name is preserved, and only
    /// `name_upper` is uppercased.
    #[test]
    fn command_from_value_accepts_simple_string_name() {
        let value = Value::Array(vec![
            Value::Simple(Bytes::from_static(b"GeT")),
            Value::Bulk(Bytes::from_static(b"key")),
        ]);
        let cmd = Command::from_value(value).unwrap();
        assert_eq!(cmd.name.as_ref(), b"GeT");
        assert_eq!(cmd.name_upper.as_ref(), b"GET");
        assert_eq!(cmd.args, vec![Value::Bulk(Bytes::from_static(b"key"))]);
    }

    /// Non-arrays, empty arrays and non-string names are all rejected.
    #[test]
    fn command_from_value_rejects_malformed_input() {
        assert!(Command::from_value(Value::Array(Vec::new())).is_err());
        assert!(Command::from_value(Value::Bulk(Bytes::from_static(b"PING"))).is_err());
        let nested_name = Value::Array(vec![Value::Array(Vec::new())]);
        assert!(Command::from_value(nested_name).is_err());
    }

    #[test]
    fn extensions_insert_get_mut() {
        let mut ext = Extensions::default();
        ext.insert(42usize);
        assert_eq!(ext.get::<usize>(), Some(&42));
        if let Some(value) = ext.get_mut::<usize>() {
            *value = 43;
        }
        assert_eq!(ext.get::<usize>(), Some(&43));

        let _clone = ext.clone();
        assert!(ext.get_mut::<usize>().is_none());
    }

    /// Each concrete type gets its own slot; re-inserting replaces it.
    #[test]
    fn extensions_are_keyed_by_type() {
        let mut ext = Extensions::default();
        assert!(ext.get::<u32>().is_none());
        ext.insert(7u32);
        ext.insert(9u64);
        assert_eq!(ext.get::<u32>(), Some(&7));
        assert_eq!(ext.get::<u64>(), Some(&9));
        ext.insert(8u32);
        assert_eq!(ext.get::<u32>(), Some(&8));
    }

    #[test]
    fn pubsub_handle_counts() {
        let count = Arc::new(AtomicUsize::new(0));
        let handle = PubSubHandle::new(count);
        assert_eq!(handle.count(), 0);
        assert_eq!(handle.increment(), 1);
        assert_eq!(handle.decrement(), 0);
        handle.set(3);
        assert_eq!(handle.count(), 3);
    }

    /// Decrementing at zero stays at zero, and clones share one counter.
    #[test]
    fn pubsub_decrement_saturates_at_zero() {
        let handle = PubSubHandle::new(Arc::new(AtomicUsize::new(0)));
        assert_eq!(handle.decrement(), 0);
        assert_eq!(handle.count(), 0);
        let shared = handle.clone();
        assert_eq!(shared.increment(), 1);
        assert_eq!(handle.count(), 1);
        assert_eq!(handle.decrement(), 0);
        assert_eq!(shared.decrement(), 0);
    }
}