hrana_client_proto/
lib.rs

1//! Messages in the Hrana protocol.
2//!
3//! Please consult the Hrana specification in the `docs/` directory for more information.
4use std::fmt;
5
6use serde::{Deserialize, Serialize};
7
8pub mod pipeline;
9
10#[derive(Serialize, Debug)]
11#[serde(tag = "type", rename_all = "snake_case")]
12pub enum ClientMsg {
13    Hello { jwt: Option<String> },
14    Request { request_id: i32, request: Request },
15}
16
17#[derive(Deserialize, Debug)]
18#[serde(tag = "type", rename_all = "snake_case")]
19pub enum ServerMsg {
20    HelloOk {},
21    HelloError { error: Error },
22    ResponseOk { request_id: i32, response: Response },
23    ResponseError { request_id: i32, error: Error },
24}
25
26#[derive(Serialize, Debug)]
27#[serde(tag = "type", rename_all = "snake_case")]
28pub enum Request {
29    OpenStream(OpenStreamReq),
30    CloseStream(CloseStreamReq),
31    Execute(ExecuteReq),
32    Batch(BatchReq),
33}
34
35#[derive(Deserialize, Debug)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum Response {
38    OpenStream(OpenStreamResp),
39    CloseStream(CloseStreamResp),
40    Execute(ExecuteResp),
41    Batch(BatchResp),
42}
43
44#[derive(Serialize, Debug)]
45pub struct OpenStreamReq {
46    pub stream_id: i32,
47}
48
49#[derive(Deserialize, Debug)]
50pub struct OpenStreamResp {}
51
52#[derive(Serialize, Debug)]
53pub struct CloseStreamReq {
54    pub stream_id: i32,
55}
56
57#[derive(Deserialize, Debug)]
58pub struct CloseStreamResp {}
59
60#[derive(Serialize, Debug)]
61pub struct ExecuteReq {
62    pub stream_id: i32,
63    pub stmt: Stmt,
64}
65
66#[derive(Deserialize, Debug)]
67pub struct ExecuteResp {
68    pub result: StmtResult,
69}
70
71#[derive(Serialize, Debug)]
72pub struct Stmt {
73    pub sql: String,
74    #[serde(default)]
75    pub args: Vec<Value>,
76    #[serde(default)]
77    pub named_args: Vec<NamedArg>,
78    pub want_rows: bool,
79}
80
81impl Stmt {
82    pub fn new(sql: impl Into<String>, want_rows: bool) -> Self {
83        let sql = sql.into();
84        Self {
85            sql,
86            want_rows,
87            named_args: Vec::new(),
88            args: Vec::new(),
89        }
90    }
91
92    pub fn bind(&mut self, val: Value) {
93        self.args.push(val);
94    }
95
96    pub fn bind_named(&mut self, name: String, value: Value) {
97        self.named_args.push(NamedArg { name, value });
98    }
99}
100
101#[derive(Serialize, Debug)]
102pub struct NamedArg {
103    pub name: String,
104    pub value: Value,
105}
106
107#[derive(Deserialize, Clone, Debug)]
108pub struct StmtResult {
109    pub cols: Vec<Col>,
110    pub rows: Vec<Vec<Value>>,
111    pub affected_row_count: u64,
112    #[serde(with = "option_i64_as_str")]
113    pub last_insert_rowid: Option<i64>,
114}
115
116#[derive(Deserialize, Clone, Debug)]
117pub struct Col {
118    pub name: Option<String>,
119}
120
121#[derive(Serialize, Deserialize, Clone, Debug)]
122#[serde(tag = "type", rename_all = "snake_case")]
123pub enum Value {
124    Null,
125    Integer {
126        #[serde(with = "i64_as_str")]
127        value: i64,
128    },
129    Float {
130        value: f64,
131    },
132    Text {
133        value: String,
134    },
135    Blob {
136        #[serde(with = "bytes_as_base64", rename = "base64")]
137        value: Vec<u8>,
138    },
139}
140
141#[derive(Serialize, Debug)]
142pub struct BatchReq {
143    pub stream_id: i32,
144    pub batch: Batch,
145}
146
147#[derive(Serialize, Debug, Default)]
148pub struct Batch {
149    steps: Vec<BatchStep>,
150}
151
152impl Batch {
153    pub fn new() -> Self {
154        Self { steps: Vec::new() }
155    }
156
157    pub fn step(&mut self, condition: Option<BatchCond>, stmt: Stmt) {
158        self.steps.push(BatchStep { condition, stmt });
159    }
160}
161
162#[derive(Serialize, Debug)]
163pub struct BatchStep {
164    condition: Option<BatchCond>,
165    stmt: Stmt,
166}
167
168#[derive(Serialize, Debug)]
169pub enum BatchCond {
170    Ok { step: i32 },
171    Error { step: i32 },
172    Not { cond: Box<BatchCond> },
173    And { conds: Vec<BatchCond> },
174    Or { conds: Vec<BatchCond> },
175}
176
177#[derive(Deserialize, Debug)]
178pub struct BatchResp {
179    pub result: BatchResult,
180}
181
182#[derive(Deserialize, Debug)]
183pub struct BatchResult {
184    pub step_results: Vec<Option<StmtResult>>,
185    pub step_errors: Vec<Option<Error>>,
186}
187
188impl<T> From<Option<T>> for Value
189where
190    T: Into<Value>,
191{
192    fn from(value: Option<T>) -> Self {
193        match value {
194            None => Self::Null,
195            Some(t) => t.into(),
196        }
197    }
198}
199
200#[derive(Deserialize, Debug, Clone)]
201pub struct Error {
202    pub message: String,
203}
204
205impl fmt::Display for Error {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        f.write_str(&self.message)
208    }
209}
210
211impl std::error::Error for Error {}
212
213mod i64_as_str {
214    use serde::{de, ser};
215    use serde::{de::Error as _, Serialize as _};
216
217    pub fn serialize<S: ser::Serializer>(value: &i64, ser: S) -> Result<S::Ok, S::Error> {
218        value.to_string().serialize(ser)
219    }
220
221    pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<i64, D::Error> {
222        let str_value = <&'de str as de::Deserialize>::deserialize(de)?;
223        str_value.parse().map_err(|_| {
224            D::Error::invalid_value(
225                de::Unexpected::Str(str_value),
226                &"decimal integer as a string",
227            )
228        })
229    }
230}
231
232mod option_i64_as_str {
233    use serde::{de, de::Error as _};
234
235    pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<Option<i64>, D::Error> {
236        let str_value = <Option<&'de str> as de::Deserialize>::deserialize(de)?;
237        str_value
238            .map(|s| {
239                s.parse().map_err(|_| {
240                    D::Error::invalid_value(de::Unexpected::Str(s), &"decimal integer as a string")
241                })
242            })
243            .transpose()
244    }
245}
246
247mod bytes_as_base64 {
248    use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine as _};
249    use serde::{de, ser};
250    use serde::{de::Error as _, Serialize as _};
251
252    pub fn serialize<S: ser::Serializer>(value: &Vec<u8>, ser: S) -> Result<S::Ok, S::Error> {
253        STANDARD_NO_PAD.encode(value).serialize(ser)
254    }
255
256    pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<Vec<u8>, D::Error> {
257        let str_value = <&'de str as de::Deserialize>::deserialize(de)?;
258        STANDARD_NO_PAD
259            .decode(str_value.trim_end_matches('='))
260            .map_err(|_| {
261                D::Error::invalid_value(
262                    de::Unexpected::Str(str_value),
263                    &"binary data encoded as base64",
264                )
265            })
266    }
267}
268
269impl std::fmt::Display for Value {
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        match self {
272            Value::Null => write!(f, "null"),
273            Value::Integer { value: n } => write!(f, "{n}"),
274            Value::Float { value: d } => write!(f, "{d}"),
275            Value::Text { value: s } => write!(f, "{}", serde_json::json!(s)),
276            Value::Blob { value: b } => {
277                use base64::{prelude::BASE64_STANDARD_NO_PAD, Engine};
278                let b = BASE64_STANDARD_NO_PAD.encode(b);
279                write!(f, "{{\"base64\": {b}}}")
280            }
281        }
282    }
283}
284
285impl From<()> for Value {
286    fn from(_: ()) -> Value {
287        Value::Null
288    }
289}
290
/// Implements `From<$typename> for Value`, wrapping the input in `Value::$variant`
/// after converting it into that variant's payload type with `From`.
macro_rules! impl_from_value {
    ($typename: ty, $variant: ident) => {
        impl From<$typename> for Value {
            fn from(raw: $typename) -> Self {
                Self::$variant {
                    value: From::from(raw),
                }
            }
        }
    };
}
300
301impl_from_value!(String, Text);
302impl_from_value!(&String, Text);
303impl_from_value!(&str, Text);
304
305impl_from_value!(i8, Integer);
306impl_from_value!(i16, Integer);
307impl_from_value!(i32, Integer);
308impl_from_value!(i64, Integer);
309
310impl_from_value!(u8, Integer);
311impl_from_value!(u16, Integer);
312impl_from_value!(u32, Integer);
313
314// rust doesn't like usize.into() for i64, so do it manually.
315impl From<usize> for Value {
316    fn from(t: usize) -> Value {
317        Value::Integer { value: t as _ }
318    }
319}
320
321impl From<isize> for Value {
322    fn from(t: isize) -> Value {
323        Value::Integer { value: t as _ }
324    }
325}
326
327impl_from_value!(f32, Float);
328impl_from_value!(f64, Float);
329
330impl_from_value!(Vec<u8>, Blob);
331
/// Implements `TryFrom<Value> for $typename`.
///
/// Succeeds when the value is `Value::$variant` and its payload converts via
/// `TryInto`. Otherwise the error string names the value and the expected variant.
macro_rules! impl_value_try_from_core {
    ($variant: ident, $typename: ty) => {
        impl TryFrom<Value> for $typename {
            type Error = String;
            fn try_from(v: Value) -> Result<$typename, Self::Error> {
                match v {
                    Value::$variant { value: inner } => {
                        inner.try_into().map_err(|err| format!("{err}"))
                    }
                    mismatched => {
                        let wanted = stringify!($variant);
                        Err(format!("cannot transform {mismatched:?} to {wanted}"))
                    }
                }
            }
        }
    };
}
348
/// Implements `TryFrom<Value>` and `TryFrom<&Value>` for a `Copy` payload type.
///
/// The by-reference impl copies the payload out before converting it.
macro_rules! impl_value_try_from_pod {
    ($variant: ident, $typename: ty) => {
        impl_value_try_from_core!($variant, $typename);

        impl TryFrom<&Value> for $typename {
            type Error = String;
            fn try_from(v: &Value) -> Result<$typename, Self::Error> {
                if let Value::$variant { value: inner } = v {
                    return (*inner).try_into().map_err(|err| format!("{err}"));
                }
                Err(format!(
                    "cannot transform {v:?} to {}",
                    stringify!($variant)
                ))
            }
        }
    };
}
367
/// Implements `TryFrom<Value> for $typename` plus a borrowing
/// `TryFrom<&'a Value> for &'a $reftype` for heap-backed payloads.
///
/// The borrowing impl lends out the payload without copying it.
macro_rules! impl_value_try_from_ref {
    ($variant: ident, $typename: ty, $reftype: ty) => {
        impl_value_try_from_core!($variant, $typename);

        impl<'a> TryFrom<&'a Value> for &'a $reftype {
            type Error = String;
            fn try_from(v: &'a Value) -> Result<&'a $reftype, Self::Error> {
                if let Value::$variant { value: inner } = v {
                    return Ok(inner);
                }
                Err(format!(
                    "cannot transform {v:?} to {}",
                    stringify!($variant)
                ))
            }
        }
    };
}
386
387impl_value_try_from_pod!(Integer, i8);
388impl_value_try_from_pod!(Integer, i16);
389impl_value_try_from_pod!(Integer, i32);
390impl_value_try_from_pod!(Integer, i64);
391impl_value_try_from_pod!(Integer, u8);
392impl_value_try_from_pod!(Integer, u16);
393impl_value_try_from_pod!(Integer, u32);
394impl_value_try_from_pod!(Integer, u64);
395impl_value_try_from_pod!(Integer, usize);
396impl_value_try_from_pod!(Integer, isize);
397impl_value_try_from_pod!(Float, f64);
398
399impl_value_try_from_ref!(Text, String, str);
400impl_value_try_from_ref!(Blob, Vec<u8>, [u8]);