Skip to main content

reddb_client/connector/
redwire.rs

1//! RedWire native client.
2//!
3//! Speaks the binary frame protocol defined by ADR 0001 directly
4//! over TCP — no engine, no tonic, no HTTP. The codec and frame
5//! types come from [`reddb_wire::redwire`]; this module handles
6//! the connect + handshake + per-request frame exchange.
7//!
8//! Auth: anonymous and bearer for now. SCRAM-SHA-256, OAuth/JWT,
9//! and mTLS are tracked as follow-up work in the parent issue.
10//!
11//! TLS (`reds://`) is not yet implemented in this slice — see the
12//! `TlsNotImplemented` error variant. The plain TCP path is the
13//! primary deliverable.
14
15use std::fmt;
16
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::TcpStream;
19
20use reddb_wire::redwire::{
21    decode_frame, encode_frame, Frame, MessageKind, FRAME_HEADER_SIZE, MAX_KNOWN_MINOR_VERSION,
22    REDWIRE_MAGIC,
23};
24
/// Credentials presented during the RedWire handshake.
#[derive(Clone)]
pub enum Auth {
    /// No credentials. The handshake advertises `anonymous` and `bearer`;
    /// if the server picks `bearer`, the handshake fails with
    /// `RedWireError::AuthRefused` because no token is available.
    Anonymous,
    /// Bearer token, sent as the `token` field of the `AuthResponse` payload.
    Bearer(String),
}

// Hand-written instead of derived so that `{:?}` (logs, panics, `dbg!`)
// never prints the bearer token.
impl fmt::Debug for Auth {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Anonymous => f.write_str("Anonymous"),
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&"<redacted>").finish(),
        }
    }
}
30
/// Failures surfaced by the RedWire client.
#[derive(Debug)]
pub enum RedWireError {
    /// Socket-level failure: connecting, or writing/reading a frame.
    Network(String),
    /// The peer sent an unexpected frame kind, or a payload could not be
    /// encoded/decoded.
    Protocol(String),
    /// The server replied `AuthFail`, or required a bearer token that the
    /// caller did not supply.
    AuthRefused(String),
    /// The server answered a query with an `Error` frame; holds its text.
    Engine(String),
    /// TLS was requested; only the plain TCP path exists so far.
    TlsNotImplemented,
}

impl fmt::Display for RedWireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All message-carrying variants render as "<label>: <detail>".
        let (label, detail) = match self {
            Self::Network(detail) => ("network", detail),
            Self::Protocol(detail) => ("protocol", detail),
            Self::AuthRefused(detail) => ("auth refused", detail),
            Self::Engine(detail) => ("engine error", detail),
            Self::TlsNotImplemented => {
                return f.write_str(
                    "RedWire-over-TLS (reds://) is not yet wired through red_client; \
                     use red:// (plain) or the full `red` binary for now",
                );
            }
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for RedWireError {}

/// Module-local shorthand: every fallible operation here fails with `RedWireError`.
type Result<T> = std::result::Result<T, RedWireError>;
59
60#[derive(Debug)]
61pub struct RedWireClient {
62    stream: TcpStream,
63    next_corr: u64,
64    #[allow(dead_code)]
65    session_id: String,
66}
67
68impl RedWireClient {
69    pub async fn connect(host: &str, port: u16, tls: bool, auth: Auth) -> Result<Self> {
70        if tls {
71            return Err(RedWireError::TlsNotImplemented);
72        }
73        let addr = format!("{host}:{port}");
74        let mut stream = TcpStream::connect(&addr)
75            .await
76            .map_err(|e| RedWireError::Network(format!("{addr}: {e}")))?;
77        // Magic discriminator + supported minor version.
78        stream
79            .write_all(&[REDWIRE_MAGIC, MAX_KNOWN_MINOR_VERSION])
80            .await
81            .map_err(|e| RedWireError::Network(e.to_string()))?;
82        let session_id = handshake(&mut stream, &auth).await?;
83        Ok(Self {
84            stream,
85            next_corr: 1,
86            session_id,
87        })
88    }
89
90    pub async fn query(&mut self, sql: &str) -> Result<String> {
91        let corr = self.next_corr_id();
92        let frame = Frame::new(MessageKind::Query, corr, sql.as_bytes().to_vec());
93        self.stream
94            .write_all(&encode_frame(&frame))
95            .await
96            .map_err(|e| RedWireError::Network(e.to_string()))?;
97        let resp = read_frame(&mut self.stream).await?;
98        match resp.kind {
99            MessageKind::Result => Ok(String::from_utf8_lossy(&resp.payload).to_string()),
100            MessageKind::Error => Err(RedWireError::Engine(
101                String::from_utf8_lossy(&resp.payload).to_string(),
102            )),
103            other => Err(RedWireError::Protocol(format!(
104                "expected Result/Error, got {other:?}"
105            ))),
106        }
107    }
108
109    fn next_corr_id(&mut self) -> u64 {
110        let n = self.next_corr;
111        self.next_corr = self.next_corr.wrapping_add(1);
112        n
113    }
114}
115
116async fn handshake(stream: &mut TcpStream, auth: &Auth) -> Result<String> {
117    let methods: Vec<&str> = match auth {
118        Auth::Bearer(_) => vec!["bearer"],
119        Auth::Anonymous => vec!["anonymous", "bearer"],
120    };
121    let mut hello_obj = serde_json::Map::new();
122    hello_obj.insert(
123        "versions".into(),
124        serde_json::Value::Array(vec![serde_json::Value::Number(serde_json::Number::from(
125            MAX_KNOWN_MINOR_VERSION,
126        ))]),
127    );
128    hello_obj.insert(
129        "auth_methods".into(),
130        serde_json::Value::Array(
131            methods
132                .iter()
133                .map(|s| serde_json::Value::String((*s).to_string()))
134                .collect(),
135        ),
136    );
137    hello_obj.insert(
138        "features".into(),
139        serde_json::Value::Number(serde_json::Number::from(0u32)),
140    );
141    let hello_bytes = serde_json::to_vec(&serde_json::Value::Object(hello_obj))
142        .map_err(|e| RedWireError::Protocol(format!("encode hello: {e}")))?;
143    let hello = Frame::new(MessageKind::Hello, 1, hello_bytes);
144    stream
145        .write_all(&encode_frame(&hello))
146        .await
147        .map_err(|e| RedWireError::Network(e.to_string()))?;
148
149    let ack = read_frame(stream).await?;
150    let chosen = match ack.kind {
151        MessageKind::HelloAck => parse_chosen_auth(&ack.payload)?,
152        MessageKind::AuthFail => {
153            return Err(RedWireError::AuthRefused(
154                parse_reason(&ack.payload).unwrap_or_else(|| "AuthFail at HelloAck".into()),
155            ));
156        }
157        other => {
158            return Err(RedWireError::Protocol(format!(
159                "expected HelloAck, got {other:?}"
160            )));
161        }
162    };
163
164    let resp_payload = match (chosen.as_str(), auth) {
165        ("anonymous", _) => Vec::new(),
166        ("bearer", Auth::Bearer(token)) => {
167            let mut obj = serde_json::Map::new();
168            obj.insert("token".into(), serde_json::Value::String(token.clone()));
169            serde_json::to_vec(&serde_json::Value::Object(obj))
170                .map_err(|e| RedWireError::Protocol(format!("encode auth: {e}")))?
171        }
172        ("bearer", Auth::Anonymous) => {
173            return Err(RedWireError::AuthRefused(
174                "server demands bearer auth but no token was supplied".into(),
175            ));
176        }
177        (other, _) => {
178            return Err(RedWireError::Protocol(format!(
179                "server picked unsupported auth method: {other}"
180            )));
181        }
182    };
183    let resp = Frame::new(MessageKind::AuthResponse, 2, resp_payload);
184    stream
185        .write_all(&encode_frame(&resp))
186        .await
187        .map_err(|e| RedWireError::Network(e.to_string()))?;
188
189    let final_frame = read_frame(stream).await?;
190    match final_frame.kind {
191        MessageKind::AuthOk => {
192            let parsed: serde_json::Value = serde_json::from_slice(&final_frame.payload)
193                .map_err(|e| RedWireError::Protocol(format!("decode auth_ok: {e}")))?;
194            let session_id = parsed
195                .as_object()
196                .and_then(|o| o.get("session_id"))
197                .and_then(|v| v.as_str())
198                .unwrap_or("")
199                .to_string();
200            Ok(session_id)
201        }
202        MessageKind::AuthFail => Err(RedWireError::AuthRefused(
203            parse_reason(&final_frame.payload).unwrap_or_else(|| "auth refused".into()),
204        )),
205        other => Err(RedWireError::Protocol(format!(
206            "expected AuthOk/AuthFail, got {other:?}"
207        ))),
208    }
209}
210
211async fn read_frame(stream: &mut TcpStream) -> Result<Frame> {
212    let mut header = [0u8; FRAME_HEADER_SIZE];
213    stream
214        .read_exact(&mut header)
215        .await
216        .map_err(|e| RedWireError::Network(e.to_string()))?;
217    let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
218    let mut buf = vec![0u8; length];
219    buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
220    if length > FRAME_HEADER_SIZE {
221        stream
222            .read_exact(&mut buf[FRAME_HEADER_SIZE..length])
223            .await
224            .map_err(|e| RedWireError::Network(e.to_string()))?;
225    }
226    let (frame, _) =
227        decode_frame(&buf).map_err(|e| RedWireError::Protocol(format!("decode: {e}")))?;
228    Ok(frame)
229}
230
231fn parse_chosen_auth(payload: &[u8]) -> Result<String> {
232    let v: serde_json::Value = serde_json::from_slice(payload)
233        .map_err(|e| RedWireError::Protocol(format!("decode hello_ack: {e}")))?;
234    v.as_object()
235        .and_then(|o| o.get("auth"))
236        .and_then(|x| x.as_str())
237        .map(String::from)
238        .ok_or_else(|| RedWireError::Protocol("hello_ack missing auth field".into()))
239}
240
241fn parse_reason(payload: &[u8]) -> Option<String> {
242    let v: serde_json::Value = serde_json::from_slice(payload).ok()?;
243    v.as_object()?.get("reason")?.as_str().map(String::from)
244}