1use std::collections::HashMap;
2use std::sync::Arc;
3
4use kameo::actor::ActorRef;
5use libp2p::bytes::BytesMut;
6use redis_protocol::resp3;
7use redis_protocol::resp3::decode::complete::decode_bytes_mut;
8use redis_protocol::resp3::types::BytesFrame;
9use sierradb::bucket::BucketId;
10use sierradb::bucket::segment::EventRecord;
11use sierradb::cache::SegmentBlockCache;
12use sierradb_cluster::ClusterActor;
13use sierradb_cluster::subscription::SubscriptionEvent;
14use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
15use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
16use tokio::sync::{mpsc, watch};
17use tokio::task::JoinSet;
18use tokio_util::sync::CancellationToken;
19use tracing::{debug, warn};
20use uuid::Uuid;
21
22use crate::request::{Command, encode_event, number, simple_str};
23
24pub struct Server {
25 cluster_ref: ActorRef<ClusterActor>,
26 caches: Arc<HashMap<BucketId, Arc<SegmentBlockCache>>>,
27 num_partitions: u16,
28 cache_capacity_bytes: usize,
29 shutdown: CancellationToken,
30 conns: JoinSet<io::Result<()>>,
31}
32
33impl Server {
34 pub fn new(
35 cluster_ref: ActorRef<ClusterActor>,
36 caches: Arc<HashMap<BucketId, Arc<SegmentBlockCache>>>,
37 num_partitions: u16,
38 cache_capacity_bytes: usize,
39 shutdown: CancellationToken,
40 ) -> Self {
41 Server {
42 cluster_ref,
43 caches,
44 num_partitions,
45 cache_capacity_bytes,
46 shutdown,
47 conns: JoinSet::new(),
48 }
49 }
50
51 pub async fn listen(mut self, addr: impl ToSocketAddrs) -> io::Result<JoinSet<io::Result<()>>> {
52 let listener = TcpListener::bind(addr).await?;
53 loop {
54 tokio::select! {
55 res = listener.accept() => {
56 match res {
57 Ok((stream, _)) => {
58 stream.set_nodelay(true)?;
59 let cluster_ref = self.cluster_ref.clone();
60 let caches = self.caches.clone();
61 let num_partitions = self.num_partitions;
62 let cache_capacity_bytes = self.cache_capacity_bytes;
63 let shutdown = self.shutdown.clone();
64 self.conns.spawn(async move {
65 let res = Conn::new(
66 cluster_ref,
67 caches,
68 num_partitions,
69 cache_capacity_bytes,
70 stream,
71 shutdown,
72 )
73 .run()
74 .await;
75 if let Err(err) = &res {
76 warn!("connection error: {err}");
77 }
78 res
79 });
80 }
81 Err(err) => warn!("failed to accept connection: {err}"),
82 }
83 }
84 _ = self.shutdown.cancelled() => {
85 return Ok(self.conns);
86 }
87 }
88 }
89 }
90}
91
92pub struct Conn {
93 pub cluster_ref: ActorRef<ClusterActor>,
94 pub caches: Arc<HashMap<BucketId, Arc<SegmentBlockCache>>>,
95 pub num_partitions: u16,
96 pub cache_capacity_bytes: usize,
97 pub stream: TcpStream,
98 pub shutdown: CancellationToken,
99 pub read: BytesMut,
100 pub write: BytesMut,
101 pub subscription_channel: Option<(
102 mpsc::WeakUnboundedSender<SubscriptionEvent>,
103 mpsc::UnboundedReceiver<SubscriptionEvent>,
104 )>,
105 pub subscriptions: HashMap<Uuid, watch::Sender<Option<u64>>>,
106}
107
108impl Conn {
109 fn new(
110 cluster_ref: ActorRef<ClusterActor>,
111 caches: Arc<HashMap<BucketId, Arc<SegmentBlockCache>>>,
112 num_partitions: u16,
113 cache_capacity_bytes: usize,
114 stream: TcpStream,
115 shutdown: CancellationToken,
116 ) -> Self {
117 let read = BytesMut::new();
118 let write = BytesMut::new();
119
120 Conn {
121 cluster_ref,
122 caches,
123 num_partitions,
124 cache_capacity_bytes,
125 stream,
126 shutdown,
127 read,
128 write,
129 subscription_channel: None,
130 subscriptions: HashMap::new(),
131 }
132 }
133
134 async fn run(mut self) -> io::Result<()> {
135 loop {
136 match &mut self.subscription_channel {
137 Some((_, rx)) => {
138 tokio::select! {
139 res = self.stream.read_buf(&mut self.read) => {
140 match res {
141 Ok(bytes_read) => {
142 if bytes_read == 0 && self.read.is_empty() {
143 self.cleanup_subscriptions();
145 return Ok(());
146 }
147
148 while let Some((frame, _, _)) =
150 decode_bytes_mut(&mut self.read).map_err(io::Error::other)?
151 {
152 let response = self.handle_request(frame).await?;
153 if let Some(resp) = response {
154 resp3::encode::complete::extend_encode(&mut self.write, &resp, false)
155 .map_err(io::Error::other)?;
156
157 self.stream.write_all(&self.write).await?;
158 self.stream.flush().await?;
159 self.write.clear();
160 }
161 }
162 }
163 Err(err) => return Err(err),
164 }
165 }
166 msg = rx.recv() => {
167 match msg {
168 Some(SubscriptionEvent::Record { subscription_id, cursor, record }) => self.send_subscription_event(subscription_id, cursor, record).await?,
169 Some(SubscriptionEvent::Error { subscription_id, error }) => {
170 warn!(%subscription_id, "subscription error: {error}");
171 }
172 Some(SubscriptionEvent::Closed { subscription_id }) => {
173 debug!(
174 subscription_id = %subscription_id,
175 "closed subscription"
176 );
177 self.subscriptions.remove(&subscription_id);
178 if self.subscriptions.is_empty() {
179 self.cleanup_subscriptions();
180 }
181 }
182 None => self.cleanup_subscriptions(),
183 }
184 }
185 _ = self.shutdown.cancelled() => {
186 rx.close();
187 return self.stream.shutdown().await;
188 }
189 }
190 }
191 None => {
192 tokio::select! {
193 res = self.stream.read_buf(&mut self.read) => {
194 let bytes_read = res?;
196 if bytes_read == 0 && self.read.is_empty() {
197 return Ok(());
198 }
199
200 while let Some((frame, _, _)) =
202 decode_bytes_mut(&mut self.read).map_err(io::Error::other)?
203 {
204 let response = self.handle_request(frame).await?;
205 if let Some(resp) = response {
206 resp3::encode::complete::extend_encode(&mut self.write, &resp, false)
207 .map_err(io::Error::other)?;
208
209 self.stream.write_all(&self.write).await?;
210 self.stream.flush().await?;
211 self.write.clear();
212 }
213 }
214 }
215 _ = self.shutdown.cancelled() => {
216 return self.stream.shutdown().await;
217 }
218 }
219 }
220 }
221 }
222 }
223
224 fn cleanup_subscriptions(&mut self) {
225 self.subscriptions.clear();
226 self.subscription_channel = None;
227 }
228
229 async fn send_subscription_event(
230 &mut self,
231 subscription_id: Uuid,
232 cursor: u64,
233 record: EventRecord,
234 ) -> io::Result<()> {
235 resp3::encode::complete::extend_encode(
236 &mut self.write,
237 &BytesFrame::Push {
238 data: vec![
239 simple_str("message"),
240 simple_str(subscription_id.to_string()),
241 number(cursor as i64),
242 encode_event(record),
243 ],
244 attributes: None,
245 },
246 false,
247 )
248 .map_err(io::Error::other)?;
249
250 self.stream.write_all(&self.write).await?;
251 self.stream.flush().await?;
252 self.write.clear();
253
254 Ok(())
255 }
256
257 async fn handle_request(&mut self, frame: BytesFrame) -> Result<Option<BytesFrame>, io::Error> {
258 match frame {
259 BytesFrame::Array { data, .. } => {
260 if data.is_empty() {
261 return Ok(Some(BytesFrame::SimpleError {
262 data: "empty command".into(),
263 attributes: None,
264 }));
265 }
266
267 let cmd = match Command::try_from(&data[0]) {
268 Ok(cmd) => cmd,
269 Err(err) => {
270 return Ok(Some(BytesFrame::SimpleError {
271 data: err.into(),
272 attributes: None,
273 }));
274 }
275 };
276 let args = &data[1..];
277 cmd.handle(args, self).await
278 }
279 _ => Ok(Some(BytesFrame::SimpleError {
280 data: "expected array command".into(),
281 attributes: None,
282 })),
283 }
284 }
285}