1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicUsize, Ordering};
6
7use bytes::{BufMut, Bytes, BytesMut};
8use tokio::sync::mpsc;
9
10use crate::resp::Value;
11use crate::response::RespError;
12
13#[derive(Debug, Clone, Eq, PartialEq)]
15pub struct Command {
16 pub name: Bytes,
17 pub name_upper: Bytes,
18 pub args: Vec<Value>,
19}
20
21impl Command {
22 pub fn new(name: Bytes, args: Vec<Value>) -> Self {
23 let name_upper = normalize_command_name(&name);
24 Self {
25 name,
26 name_upper,
27 args,
28 }
29 }
30
31 pub fn from_value(value: Value) -> Result<Self, RespError> {
32 match value {
33 Value::Array(mut items) => {
34 if items.is_empty() {
35 return Err(RespError::invalid_data("ERR empty command"));
36 }
37 let name_value = items.remove(0);
38 let name = match name_value {
39 Value::Bulk(b) | Value::Simple(b) => b,
40 other => {
41 return Err(RespError::invalid_data(format!(
42 "ERR invalid command name: {:?}",
43 other
44 )));
45 }
46 };
47 Ok(Command::new(name, items))
48 }
49 other => Err(RespError::invalid_data(format!(
50 "ERR expected array, got {:?}",
51 other
52 ))),
53 }
54 }
55}
56
57fn normalize_command_name(name: &Bytes) -> Bytes {
58 let mut needs = false;
59 for &b in name.iter() {
60 if b.is_ascii_lowercase() {
61 needs = true;
62 break;
63 }
64 }
65 if !needs {
66 return name.clone();
67 }
68 let mut buf = BytesMut::with_capacity(name.len());
69 for &b in name.iter() {
70 buf.put_u8(b.to_ascii_uppercase());
71 }
72 buf.freeze()
73}
74
75#[derive(Debug, Clone)]
77pub struct RequestContext {
78 pub command: Command,
79 pub peer_addr: SocketAddr,
80 pub local_addr: SocketAddr,
81 pub client_id: u64,
82 pub extensions: Extensions,
83 pub push: PushHandle,
84 pub pubsub: PubSubHandle,
85}
86
/// A type-keyed map holding at most one value per concrete type.
///
/// Values are stored behind `Arc`, so cloning an `Extensions` shares the stored
/// values with the clone instead of copying them.
#[derive(Clone, Debug, Default)]
pub struct Extensions {
    inner: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Stores `value`, replacing any value of the same type that was already present.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
        let stored: Arc<dyn Any + Send + Sync> = Arc::new(value);
        self.inner.insert(TypeId::of::<T>(), stored);
    }

    /// Returns a shared reference to the stored value of type `T`, if there is one.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let stored = self.inner.get(&TypeId::of::<T>())?;
        let erased: &(dyn Any + Send + Sync) = &**stored;
        erased.downcast_ref::<T>()
    }

    /// Returns a mutable reference to the stored value of type `T`.
    ///
    /// Yields `None` when no such value exists, and also when the value is
    /// currently shared with a clone of this `Extensions` (the `Arc` is not unique).
    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
        let slot = self.inner.get_mut(&TypeId::of::<T>())?;
        // `Arc::get_mut` only succeeds while this map holds the sole reference.
        let exclusive = Arc::get_mut(slot)?;
        exclusive.downcast_mut::<T>()
    }
}
112
113#[derive(Debug, Clone)]
115pub struct PushHandle {
116 tx: mpsc::Sender<Value>,
117 close_tx: mpsc::Sender<()>,
118}
119
120impl PushHandle {
121 pub(crate) fn new(tx: mpsc::Sender<Value>, close_tx: mpsc::Sender<()>) -> Self {
122 Self { tx, close_tx }
123 }
124
125 pub async fn send(&self, value: Value) -> Result<(), PushError> {
126 match self.tx.try_send(value) {
127 Ok(()) => Ok(()),
128 Err(mpsc::error::TrySendError::Full(_)) => {
129 let _ = self.close_tx.try_send(());
130 Err(PushError::Full)
131 }
132 Err(mpsc::error::TrySendError::Closed(_)) => Err(PushError::Closed),
133 }
134 }
135}
136
/// Reason a value could not be queued on the push channel.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into boxed or wrapping error types.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PushError {
    /// The channel had no free capacity.
    Full,
    /// The receiving side of the channel was closed.
    Closed,
}

impl std::fmt::Display for PushError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            PushError::Full => "push channel is full",
            PushError::Closed => "push channel is closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PushError {}
143
/// Cloneable handle to a shared counter.
///
/// All clones refer to the same `AtomicUsize`, so an update through one handle is
/// visible through every other. The counter saturates at both ends: it never drops
/// below zero and never wraps past `usize::MAX`.
#[derive(Debug, Clone)]
pub struct PubSubHandle {
    count: Arc<AtomicUsize>,
}

impl PubSubHandle {
    /// Wraps an existing shared counter.
    pub(crate) fn new(count: Arc<AtomicUsize>) -> Self {
        Self { count }
    }

    /// Overwrites the counter with `count`.
    pub fn set(&self, count: usize) {
        self.count.store(count, Ordering::Release);
    }

    /// Adds one to the counter and returns the new value.
    ///
    /// At `usize::MAX` the counter is left unchanged and `usize::MAX` is returned,
    /// instead of wrapping the stored value to zero and overflowing the result.
    pub fn increment(&self) -> usize {
        let outcome = self
            .count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current.checked_add(1)
            });
        match outcome {
            // `previous + 1` cannot overflow: the update only succeeds below `usize::MAX`.
            Ok(previous) => previous + 1,
            Err(saturated) => saturated,
        }
    }

    /// Subtracts one from the counter and returns the new value.
    ///
    /// At zero the counter is left unchanged and zero is returned.
    pub fn decrement(&self) -> usize {
        let outcome = self
            .count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current.checked_sub(1)
            });
        match outcome {
            // `previous - 1` cannot underflow: the update only succeeds above zero.
            Ok(previous) => previous - 1,
            Err(floor) => floor,
        }
    }

    /// Returns the current counter value.
    pub fn count(&self) -> usize {
        self.count.load(Ordering::Acquire)
    }
}
177
178#[derive(Debug, Clone)]
180pub struct Cmd(pub Command);
181
/// Newtype wrapper around a shared `Arc<T>`; the `Arc` is the public field `0`.
///
/// Cloning a `State` clones the `Arc` (a reference-count bump), never the `T`
/// behind it, so `State<T>` is `Clone` for every `T`.
#[derive(Debug)]
pub struct State<T>(pub Arc<T>);

// Written by hand because `#[derive(Clone)]` would add a `T: Clone` bound that
// `Arc<T>` does not need.
impl<T> Clone for State<T> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
185
/// Newtype wrapper around a peer `SocketAddr`; the address is the public field `0`.
#[derive(Clone, Copy, Debug)]
pub struct PeerAddr(pub SocketAddr);
189
/// Newtype wrapper around a local `SocketAddr`; the address is the public field `0`.
#[derive(Clone, Copy, Debug)]
pub struct LocalAddr(pub SocketAddr);
193
/// Newtype wrapper around a numeric client identifier; the id is the public field `0`.
#[derive(Clone, Copy, Debug)]
pub struct ClientId(pub u64);
197
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// A lowercase bulk-string name is uppercased into `name_upper`, and the
    /// remaining array items become the arguments.
    #[test]
    fn command_from_value_normalizes() {
        let items = vec![
            Value::Bulk(Bytes::from_static(b"ping")),
            Value::Bulk(Bytes::from_static(b"hi")),
        ];
        let parsed = Command::from_value(Value::Array(items)).unwrap();

        assert_eq!(parsed.name_upper.as_ref(), b"PING");
        assert_eq!(parsed.args.len(), 1);
    }

    /// `get_mut` works while the value is uniquely owned and yields `None` once a
    /// clone shares it.
    #[test]
    fn extensions_insert_get_mut() {
        let mut extensions = Extensions::default();
        extensions.insert(42usize);
        assert_eq!(extensions.get::<usize>(), Some(&42));

        if let Some(slot) = extensions.get_mut::<usize>() {
            *slot = 43;
        }
        assert_eq!(extensions.get::<usize>(), Some(&43));

        // Keeping a clone alive means the stored `Arc` is no longer unique.
        let _shared = extensions.clone();
        assert!(extensions.get_mut::<usize>().is_none());
    }

    /// Increment, decrement, and set are all observable through `count`.
    #[test]
    fn pubsub_handle_counts() {
        let handle = PubSubHandle::new(Arc::new(AtomicUsize::new(0)));
        assert_eq!(handle.count(), 0);

        assert_eq!(handle.increment(), 1);
        assert_eq!(handle.decrement(), 0);

        handle.set(3);
        assert_eq!(handle.count(), 3);
    }
}