reddb_client/connector/
redwire.rs1use std::fmt;
16
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::TcpStream;
19
20use reddb_wire::redwire::{
21 decode_frame, encode_frame, Frame, MessageKind, FRAME_HEADER_SIZE, MAX_KNOWN_MINOR_VERSION,
22 REDWIRE_MAGIC,
23};
24
/// Credentials presented during the RedWire handshake.
#[derive(Clone)]
pub enum Auth {
    /// No credentials; the handshake only succeeds if the server selects `anonymous`.
    Anonymous,
    /// A bearer token, sent as a `{"token": ...}` JSON `AuthResponse` when the server
    /// selects `bearer`.
    Bearer(String),
}

// Written by hand instead of derived so that `{:?}` (logs, panics, `dbg!`) never
// prints the bearer token itself.
impl fmt::Debug for Auth {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Anonymous => f.write_str("Anonymous"),
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&"<redacted>").finish(),
        }
    }
}
30
/// Everything that can go wrong while talking RedWire to a server.
#[derive(Debug)]
pub enum RedWireError {
    /// Socket-level failure (connect, read or write), with a description of the I/O error.
    Network(String),
    /// A frame or payload could not be encoded/decoded, or the peer sent an unexpected kind.
    Protocol(String),
    /// The server rejected authentication, or demanded a credential the caller lacks.
    AuthRefused(String),
    /// The server answered a query with an `Error` frame; carries its payload as text.
    Engine(String),
    /// A TLS connection was requested; this client only speaks plain TCP.
    TlsNotImplemented,
}

impl fmt::Display for RedWireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except the TLS one renders as "<category>: <detail>".
        let (category, detail) = match self {
            RedWireError::Network(msg) => ("network", msg),
            RedWireError::Protocol(msg) => ("protocol", msg),
            RedWireError::AuthRefused(msg) => ("auth refused", msg),
            RedWireError::Engine(msg) => ("engine error", msg),
            RedWireError::TlsNotImplemented => {
                return f.write_str(
                    "RedWire-over-TLS (reds://) is not yet wired through red_client; \
                     use red:// (plain) or the full `red` binary for now",
                );
            }
        };
        write!(f, "{category}: {detail}")
    }
}

impl std::error::Error for RedWireError {}

/// Module-local shorthand: every fallible operation here fails with [`RedWireError`].
type Result<T> = std::result::Result<T, RedWireError>;
59
60#[derive(Debug)]
61pub struct RedWireClient {
62 stream: TcpStream,
63 next_corr: u64,
64 #[allow(dead_code)]
65 session_id: String,
66}
67
68impl RedWireClient {
69 pub async fn connect(host: &str, port: u16, tls: bool, auth: Auth) -> Result<Self> {
70 if tls {
71 return Err(RedWireError::TlsNotImplemented);
72 }
73 let addr = format!("{host}:{port}");
74 let mut stream = TcpStream::connect(&addr)
75 .await
76 .map_err(|e| RedWireError::Network(format!("{addr}: {e}")))?;
77 stream
79 .write_all(&[REDWIRE_MAGIC, MAX_KNOWN_MINOR_VERSION])
80 .await
81 .map_err(|e| RedWireError::Network(e.to_string()))?;
82 let session_id = handshake(&mut stream, &auth).await?;
83 Ok(Self {
84 stream,
85 next_corr: 1,
86 session_id,
87 })
88 }
89
90 pub async fn query(&mut self, sql: &str) -> Result<String> {
91 let corr = self.next_corr_id();
92 let frame = Frame::new(MessageKind::Query, corr, sql.as_bytes().to_vec());
93 self.stream
94 .write_all(&encode_frame(&frame))
95 .await
96 .map_err(|e| RedWireError::Network(e.to_string()))?;
97 let resp = read_frame(&mut self.stream).await?;
98 match resp.kind {
99 MessageKind::Result => Ok(String::from_utf8_lossy(&resp.payload).to_string()),
100 MessageKind::Error => Err(RedWireError::Engine(
101 String::from_utf8_lossy(&resp.payload).to_string(),
102 )),
103 other => Err(RedWireError::Protocol(format!(
104 "expected Result/Error, got {other:?}"
105 ))),
106 }
107 }
108
109 fn next_corr_id(&mut self) -> u64 {
110 let n = self.next_corr;
111 self.next_corr = self.next_corr.wrapping_add(1);
112 n
113 }
114}
115
116async fn handshake(stream: &mut TcpStream, auth: &Auth) -> Result<String> {
117 let methods: Vec<&str> = match auth {
118 Auth::Bearer(_) => vec!["bearer"],
119 Auth::Anonymous => vec!["anonymous", "bearer"],
120 };
121 let mut hello_obj = serde_json::Map::new();
122 hello_obj.insert(
123 "versions".into(),
124 serde_json::Value::Array(vec![serde_json::Value::Number(serde_json::Number::from(
125 MAX_KNOWN_MINOR_VERSION,
126 ))]),
127 );
128 hello_obj.insert(
129 "auth_methods".into(),
130 serde_json::Value::Array(
131 methods
132 .iter()
133 .map(|s| serde_json::Value::String((*s).to_string()))
134 .collect(),
135 ),
136 );
137 hello_obj.insert(
138 "features".into(),
139 serde_json::Value::Number(serde_json::Number::from(0u32)),
140 );
141 let hello_bytes = serde_json::to_vec(&serde_json::Value::Object(hello_obj))
142 .map_err(|e| RedWireError::Protocol(format!("encode hello: {e}")))?;
143 let hello = Frame::new(MessageKind::Hello, 1, hello_bytes);
144 stream
145 .write_all(&encode_frame(&hello))
146 .await
147 .map_err(|e| RedWireError::Network(e.to_string()))?;
148
149 let ack = read_frame(stream).await?;
150 let chosen = match ack.kind {
151 MessageKind::HelloAck => parse_chosen_auth(&ack.payload)?,
152 MessageKind::AuthFail => {
153 return Err(RedWireError::AuthRefused(
154 parse_reason(&ack.payload).unwrap_or_else(|| "AuthFail at HelloAck".into()),
155 ));
156 }
157 other => {
158 return Err(RedWireError::Protocol(format!(
159 "expected HelloAck, got {other:?}"
160 )));
161 }
162 };
163
164 let resp_payload = match (chosen.as_str(), auth) {
165 ("anonymous", _) => Vec::new(),
166 ("bearer", Auth::Bearer(token)) => {
167 let mut obj = serde_json::Map::new();
168 obj.insert("token".into(), serde_json::Value::String(token.clone()));
169 serde_json::to_vec(&serde_json::Value::Object(obj))
170 .map_err(|e| RedWireError::Protocol(format!("encode auth: {e}")))?
171 }
172 ("bearer", Auth::Anonymous) => {
173 return Err(RedWireError::AuthRefused(
174 "server demands bearer auth but no token was supplied".into(),
175 ));
176 }
177 (other, _) => {
178 return Err(RedWireError::Protocol(format!(
179 "server picked unsupported auth method: {other}"
180 )));
181 }
182 };
183 let resp = Frame::new(MessageKind::AuthResponse, 2, resp_payload);
184 stream
185 .write_all(&encode_frame(&resp))
186 .await
187 .map_err(|e| RedWireError::Network(e.to_string()))?;
188
189 let final_frame = read_frame(stream).await?;
190 match final_frame.kind {
191 MessageKind::AuthOk => {
192 let parsed: serde_json::Value = serde_json::from_slice(&final_frame.payload)
193 .map_err(|e| RedWireError::Protocol(format!("decode auth_ok: {e}")))?;
194 let session_id = parsed
195 .as_object()
196 .and_then(|o| o.get("session_id"))
197 .and_then(|v| v.as_str())
198 .unwrap_or("")
199 .to_string();
200 Ok(session_id)
201 }
202 MessageKind::AuthFail => Err(RedWireError::AuthRefused(
203 parse_reason(&final_frame.payload).unwrap_or_else(|| "auth refused".into()),
204 )),
205 other => Err(RedWireError::Protocol(format!(
206 "expected AuthOk/AuthFail, got {other:?}"
207 ))),
208 }
209}
210
211async fn read_frame(stream: &mut TcpStream) -> Result<Frame> {
212 let mut header = [0u8; FRAME_HEADER_SIZE];
213 stream
214 .read_exact(&mut header)
215 .await
216 .map_err(|e| RedWireError::Network(e.to_string()))?;
217 let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
218 let mut buf = vec![0u8; length];
219 buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
220 if length > FRAME_HEADER_SIZE {
221 stream
222 .read_exact(&mut buf[FRAME_HEADER_SIZE..length])
223 .await
224 .map_err(|e| RedWireError::Network(e.to_string()))?;
225 }
226 let (frame, _) =
227 decode_frame(&buf).map_err(|e| RedWireError::Protocol(format!("decode: {e}")))?;
228 Ok(frame)
229}
230
231fn parse_chosen_auth(payload: &[u8]) -> Result<String> {
232 let v: serde_json::Value = serde_json::from_slice(payload)
233 .map_err(|e| RedWireError::Protocol(format!("decode hello_ack: {e}")))?;
234 v.as_object()
235 .and_then(|o| o.get("auth"))
236 .and_then(|x| x.as_str())
237 .map(String::from)
238 .ok_or_else(|| RedWireError::Protocol("hello_ack missing auth field".into()))
239}
240
241fn parse_reason(payload: &[u8]) -> Option<String> {
242 let v: serde_json::Value = serde_json::from_slice(payload).ok()?;
243 v.as_object()?.get("reason")?.as_str().map(String::from)
244}