libsql_hrana/
proto.rs

1//! Structures in Hrana that are common for WebSockets and HTTP.
2
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6
7#[derive(Serialize, Deserialize, prost::Message)]
8pub struct PipelineReqBody {
9    #[prost(string, optional, tag = "1")]
10    pub baton: Option<String>,
11    #[prost(message, repeated, tag = "2")]
12    pub requests: Vec<StreamRequest>,
13}
14
15#[derive(Serialize, Deserialize, prost::Message)]
16pub struct PipelineRespBody {
17    #[prost(string, optional, tag = "1")]
18    pub baton: Option<String>,
19    #[prost(string, optional, tag = "2")]
20    pub base_url: Option<String>,
21    #[prost(message, repeated, tag = "3")]
22    pub results: Vec<StreamResult>,
23}
24
25#[derive(Serialize, Deserialize, Default, Debug)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum StreamResult {
28    #[default]
29    None,
30    Ok {
31        response: StreamResponse,
32    },
33    Error {
34        error: Error,
35    },
36}
37
38#[derive(Serialize, Deserialize, prost::Message)]
39pub struct CursorReqBody {
40    #[prost(string, optional, tag = "1")]
41    pub baton: Option<String>,
42    #[prost(message, required, tag = "2")]
43    pub batch: Batch,
44}
45
46#[derive(Serialize, Deserialize, prost::Message)]
47pub struct CursorRespBody {
48    #[prost(string, optional, tag = "1")]
49    pub baton: Option<String>,
50    #[prost(string, optional, tag = "2")]
51    pub base_url: Option<String>,
52}
53
54#[derive(Serialize, Deserialize, Debug, Default)]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum StreamRequest {
57    #[serde(skip_deserializing)]
58    #[default]
59    None,
60    Close(CloseStreamReq),
61    Execute(ExecuteStreamReq),
62    Batch(BatchStreamReq),
63    Sequence(SequenceStreamReq),
64    Describe(DescribeStreamReq),
65    StoreSql(StoreSqlStreamReq),
66    CloseSql(CloseSqlStreamReq),
67    GetAutocommit(GetAutocommitStreamReq),
68}
69
70#[derive(Serialize, Deserialize, Debug)]
71#[serde(tag = "type", rename_all = "snake_case")]
72pub enum StreamResponse {
73    Close(CloseStreamResp),
74    Execute(ExecuteStreamResp),
75    Batch(BatchStreamResp),
76    Sequence(SequenceStreamResp),
77    Describe(DescribeStreamResp),
78    StoreSql(StoreSqlStreamResp),
79    CloseSql(CloseSqlStreamResp),
80    GetAutocommit(GetAutocommitStreamResp),
81}
82
83#[derive(Serialize, Deserialize, prost::Message)]
84pub struct CloseStreamReq {}
85
86#[derive(Serialize, Deserialize, prost::Message)]
87pub struct CloseStreamResp {}
88
89#[derive(Serialize, Deserialize, prost::Message)]
90pub struct ExecuteStreamReq {
91    #[prost(message, required, tag = "1")]
92    pub stmt: Stmt,
93}
94
95#[derive(Serialize, Deserialize, prost::Message)]
96pub struct ExecuteStreamResp {
97    #[prost(message, required, tag = "1")]
98    pub result: StmtResult,
99}
100
101#[derive(Serialize, Deserialize, prost::Message)]
102pub struct BatchStreamReq {
103    #[prost(message, required, tag = "1")]
104    pub batch: Batch,
105}
106
107#[derive(Serialize, Deserialize, prost::Message)]
108pub struct BatchStreamResp {
109    #[prost(message, required, tag = "1")]
110    pub result: BatchResult,
111}
112
113#[derive(Serialize, Deserialize, prost::Message)]
114pub struct SequenceStreamReq {
115    #[serde(default)]
116    #[prost(string, optional, tag = "1")]
117    pub sql: Option<String>,
118    #[serde(default)]
119    #[prost(int32, optional, tag = "2")]
120    pub sql_id: Option<i32>,
121    #[serde(default, with = "option_u64_as_str")]
122    #[prost(uint64, optional, tag = "3")]
123    pub replication_index: Option<u64>,
124}
125
126#[derive(Serialize, Deserialize, prost::Message)]
127pub struct SequenceStreamResp {}
128
129#[derive(Serialize, Deserialize, prost::Message)]
130pub struct DescribeStreamReq {
131    #[serde(default)]
132    #[prost(string, optional, tag = "1")]
133    pub sql: Option<String>,
134    #[serde(default)]
135    #[prost(int32, optional, tag = "2")]
136    pub sql_id: Option<i32>,
137    #[serde(default, with = "option_u64_as_str")]
138    #[prost(uint64, optional, tag = "3")]
139    pub replication_index: Option<u64>,
140}
141
142#[derive(Serialize, Deserialize, prost::Message)]
143pub struct DescribeStreamResp {
144    #[prost(message, required, tag = "1")]
145    pub result: DescribeResult,
146}
147
148#[derive(Serialize, Deserialize, prost::Message)]
149pub struct StoreSqlStreamReq {
150    #[prost(int32, tag = "1")]
151    pub sql_id: i32,
152    #[prost(string, tag = "2")]
153    pub sql: String,
154}
155
156#[derive(Serialize, Deserialize, prost::Message)]
157pub struct StoreSqlStreamResp {}
158
159#[derive(Serialize, Deserialize, prost::Message)]
160pub struct CloseSqlStreamReq {
161    #[prost(int32, tag = "1")]
162    pub sql_id: i32,
163}
164
165#[derive(Serialize, Deserialize, prost::Message)]
166pub struct CloseSqlStreamResp {}
167
168#[derive(Serialize, Deserialize, prost::Message)]
169pub struct GetAutocommitStreamReq {}
170
171#[derive(Serialize, Deserialize, prost::Message)]
172pub struct GetAutocommitStreamResp {
173    #[prost(bool, tag = "1")]
174    pub is_autocommit: bool,
175}
176
177#[derive(Clone, Deserialize, Serialize, prost::Message)]
178pub struct Error {
179    #[prost(string, tag = "1")]
180    pub message: String,
181    #[prost(string, tag = "2")]
182    pub code: String,
183}
184
185#[derive(Clone, Deserialize, Serialize, prost::Message)]
186pub struct Stmt {
187    #[serde(default)]
188    #[prost(string, optional, tag = "1")]
189    pub sql: Option<String>,
190    #[serde(default)]
191    #[prost(int32, optional, tag = "2")]
192    pub sql_id: Option<i32>,
193    #[serde(default)]
194    #[prost(message, repeated, tag = "3")]
195    pub args: Vec<Value>,
196    #[serde(default)]
197    #[prost(message, repeated, tag = "4")]
198    pub named_args: Vec<NamedArg>,
199    #[serde(default)]
200    #[prost(bool, optional, tag = "5")]
201    pub want_rows: Option<bool>,
202    #[serde(default, with = "option_u64_as_str")]
203    #[prost(uint64, optional, tag = "6")]
204    pub replication_index: Option<u64>,
205}
206
207impl Stmt {
208    pub fn new<S: Into<String>>(sql: S, want_rows: bool) -> Self {
209        Stmt {
210            sql: Some(sql.into()),
211            sql_id: None,
212            args: vec![],
213            named_args: vec![],
214            want_rows: Some(want_rows),
215            replication_index: None,
216        }
217    }
218
219    pub fn bind(&mut self, value: Value) {
220        self.args.push(value);
221    }
222
223    pub fn bind_named(&mut self, name: String, value: Value) {
224        self.named_args.push(NamedArg { name, value });
225    }
226}
227
228#[derive(Clone, Deserialize, Serialize, prost::Message)]
229pub struct NamedArg {
230    #[prost(string, tag = "1")]
231    pub name: String,
232    #[prost(message, required, tag = "2")]
233    pub value: Value,
234}
235
236#[derive(Clone, Serialize, Deserialize, prost::Message)]
237pub struct StmtResult {
238    #[prost(message, repeated, tag = "1")]
239    pub cols: Vec<Col>,
240    #[prost(message, repeated, tag = "2")]
241    pub rows: Vec<Row>,
242    #[prost(uint64, tag = "3")]
243    pub affected_row_count: u64,
244    #[serde(with = "option_i64_as_str")]
245    #[prost(sint64, optional, tag = "4")]
246    pub last_insert_rowid: Option<i64>,
247    #[serde(default, with = "option_u64_as_str")]
248    #[prost(uint64, optional, tag = "5")]
249    pub replication_index: Option<u64>,
250    #[prost(uint64, tag = "6")]
251    #[serde(default)]
252    pub rows_read: u64,
253    #[prost(uint64, tag = "7")]
254    #[serde(default)]
255    pub rows_written: u64,
256    #[prost(double, tag = "8")]
257    #[serde(default)]
258    pub query_duration_ms: f64,
259}
260
261#[derive(Clone, Deserialize, Serialize, prost::Message)]
262pub struct Col {
263    #[prost(string, optional, tag = "1")]
264    pub name: Option<String>,
265    #[prost(string, optional, tag = "2")]
266    pub decltype: Option<String>,
267}
268
269#[derive(Clone, Deserialize, Serialize, prost::Message)]
270#[serde(transparent)]
271pub struct Row {
272    #[prost(message, repeated, tag = "1")]
273    pub values: Vec<Value>,
274}
275
276#[derive(Clone, Deserialize, Serialize, prost::Message)]
277pub struct Batch {
278    #[prost(message, repeated, tag = "1")]
279    pub steps: Vec<BatchStep>,
280    #[prost(uint64, optional, tag = "2")]
281    #[serde(default, with = "option_u64_as_str")]
282    pub replication_index: Option<u64>,
283}
284
285impl Batch {
286    pub fn single(stmt: Stmt) -> Self {
287        Batch {
288            steps: vec![BatchStep {
289                condition: None,
290                stmt,
291            }],
292            replication_index: None,
293        }
294    }
295    pub fn transactional<T: IntoIterator<Item = Stmt>>(stmts: T) -> Self {
296        let mut steps = Vec::new();
297        steps.push(BatchStep {
298            condition: None,
299            stmt: Stmt::new("BEGIN TRANSACTION", false),
300        });
301        let mut count = 0u32;
302        for (step, stmt) in stmts.into_iter().enumerate() {
303            count += 1;
304            let condition = Some(BatchCond::Ok { step: step as u32 });
305            steps.push(BatchStep { condition, stmt });
306        }
307        steps.push(BatchStep {
308            condition: Some(BatchCond::Ok { step: count }),
309            stmt: Stmt::new("COMMIT", false),
310        });
311        steps.push(BatchStep {
312            condition: Some(BatchCond::Not {
313                cond: Box::new(BatchCond::Ok { step: count + 1 }),
314            }),
315            stmt: Stmt::new("ROLLBACK", false),
316        });
317        Batch {
318            steps,
319            replication_index: None,
320        }
321    }
322}
323
324impl FromIterator<Stmt> for Batch {
325    fn from_iter<T: IntoIterator<Item = Stmt>>(stmts: T) -> Self {
326        let mut steps = Vec::new();
327        for (step, stmt) in stmts.into_iter().enumerate() {
328            let condition = if step > 0 {
329                Some(BatchCond::Ok {
330                    step: (step - 1) as u32,
331                })
332            } else {
333                None
334            };
335            steps.push(BatchStep { condition, stmt });
336        }
337        Batch {
338            steps,
339            replication_index: None,
340        }
341    }
342}
343
344#[derive(Clone, Deserialize, Serialize, prost::Message)]
345pub struct BatchStep {
346    #[serde(default)]
347    #[prost(message, optional, tag = "1")]
348    pub condition: Option<BatchCond>,
349    #[prost(message, required, tag = "2")]
350    pub stmt: Stmt,
351}
352
353#[derive(Clone, Deserialize, Serialize, Debug, Default)]
354pub struct BatchResult {
355    pub step_results: Vec<Option<StmtResult>>,
356    pub step_errors: Vec<Option<Error>>,
357    #[serde(default, with = "option_u64_as_str")]
358    pub replication_index: Option<u64>,
359}
360
361#[derive(Clone, Deserialize, Serialize, Debug, Default)]
362#[serde(tag = "type", rename_all = "snake_case")]
363pub enum BatchCond {
364    #[serde(skip_deserializing)]
365    #[default]
366    None,
367    Ok {
368        step: u32,
369    },
370    Error {
371        step: u32,
372    },
373    Not {
374        cond: Box<BatchCond>,
375    },
376    And(BatchCondList),
377    Or(BatchCondList),
378    IsAutocommit {},
379}
380
381#[derive(Clone, Deserialize, Serialize, prost::Message)]
382pub struct BatchCondList {
383    #[prost(message, repeated, tag = "1")]
384    pub conds: Vec<BatchCond>,
385}
386
387#[derive(Clone, Deserialize, Serialize, Debug, Default)]
388#[serde(tag = "type", rename_all = "snake_case")]
389pub enum CursorEntry {
390    #[serde(skip_deserializing)]
391    #[default]
392    None,
393    StepBegin(StepBeginEntry),
394    StepEnd(StepEndEntry),
395    StepError(StepErrorEntry),
396    Row {
397        row: Row,
398    },
399    Error {
400        error: Error,
401    },
402    ReplicationIndex {
403        replication_index: Option<u64>,
404    },
405}
406
407#[derive(Clone, Deserialize, Serialize, prost::Message)]
408pub struct StepBeginEntry {
409    #[prost(uint32, tag = "1")]
410    pub step: u32,
411    #[prost(message, repeated, tag = "2")]
412    pub cols: Vec<Col>,
413}
414
415#[derive(Clone, Deserialize, Serialize, prost::Message)]
416pub struct StepEndEntry {
417    #[prost(uint64, tag = "1")]
418    pub affected_row_count: u64,
419    #[prost(sint64, optional, tag = "2")]
420    pub last_insert_rowid: Option<i64>,
421}
422
423#[derive(Clone, Deserialize, Serialize, prost::Message)]
424pub struct StepErrorEntry {
425    #[prost(uint32, tag = "1")]
426    pub step: u32,
427    #[prost(message, required, tag = "2")]
428    pub error: Error,
429}
430
431#[derive(Clone, Deserialize, Serialize, prost::Message)]
432pub struct DescribeResult {
433    #[prost(message, repeated, tag = "1")]
434    pub params: Vec<DescribeParam>,
435    #[prost(message, repeated, tag = "2")]
436    pub cols: Vec<DescribeCol>,
437    #[prost(bool, tag = "3")]
438    pub is_explain: bool,
439    #[prost(bool, tag = "4")]
440    pub is_readonly: bool,
441}
442
443#[derive(Clone, Deserialize, Serialize, prost::Message)]
444pub struct DescribeParam {
445    #[prost(string, optional, tag = "1")]
446    pub name: Option<String>,
447}
448
449#[derive(Clone, Deserialize, Serialize, prost::Message)]
450pub struct DescribeCol {
451    #[prost(string, tag = "1")]
452    pub name: String,
453    #[prost(string, optional, tag = "2")]
454    pub decltype: Option<String>,
455}
456
457#[derive(Debug, Clone, Serialize, Deserialize, Default)]
458#[serde(tag = "type", rename_all = "snake_case")]
459pub enum Value {
460    #[serde(skip_deserializing)]
461    #[default]
462    None,
463    Null,
464    Integer {
465        #[serde(with = "i64_as_str")]
466        value: i64,
467    },
468    Float {
469        value: f64,
470    },
471    Text {
472        value: Arc<str>,
473    },
474    Blob {
475        #[serde(with = "bytes_as_base64", rename = "base64")]
476        value: Bytes,
477    },
478}
479
480mod i64_as_str {
481    use serde::{de, ser};
482    use serde::{de::Error as _, Serialize as _};
483
484    pub fn serialize<S: ser::Serializer>(value: &i64, ser: S) -> Result<S::Ok, S::Error> {
485        value.to_string().serialize(ser)
486    }
487
488    pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<i64, D::Error> {
489        let str_value = <&'de str as de::Deserialize>::deserialize(de)?;
490        str_value.parse().map_err(|_| {
491            D::Error::invalid_value(
492                de::Unexpected::Str(str_value),
493                &"decimal integer as a string",
494            )
495        })
496    }
497}
498
499mod option_i64_as_str {
500    use serde::de::{Error, Visitor};
501    use serde::{ser, Deserializer, Serialize as _};
502
503    pub fn serialize<S: ser::Serializer>(value: &Option<i64>, ser: S) -> Result<S::Ok, S::Error> {
504        value.map(|v| v.to_string()).serialize(ser)
505    }
506
507    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<i64>, D::Error> {
508        struct V;
509
510        impl<'de> Visitor<'de> for V {
511            type Value = Option<i64>;
512
513            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
514                write!(formatter, "a string representing a signed integer, or null")
515            }
516
517            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
518            where
519                D: Deserializer<'de>,
520            {
521                deserializer.deserialize_any(V)
522            }
523
524            fn visit_none<E>(self) -> Result<Self::Value, E>
525            where
526                E: Error,
527            {
528                Ok(None)
529            }
530
531            fn visit_unit<E>(self) -> Result<Self::Value, E>
532            where
533                E: Error,
534            {
535                Ok(None)
536            }
537
538            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
539            where
540                E: Error,
541            {
542                Ok(Some(v))
543            }
544
545            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
546            where
547                E: Error,
548            {
549                v.parse().map_err(E::custom).map(Some)
550            }
551        }
552
553        d.deserialize_option(V)
554    }
555}
556
557pub mod option_u64_as_str {
558    use serde::de::Error;
559    use serde::{de::Visitor, ser, Deserializer, Serialize as _};
560
561    pub fn serialize<S: ser::Serializer>(value: &Option<u64>, ser: S) -> Result<S::Ok, S::Error> {
562        value.map(|v| v.to_string()).serialize(ser)
563    }
564
565    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<u64>, D::Error> {
566        struct V;
567
568        impl<'de> Visitor<'de> for V {
569            type Value = Option<u64>;
570
571            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
572                write!(formatter, "a string representing an integer, or null")
573            }
574
575            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
576            where
577                D: Deserializer<'de>,
578            {
579                deserializer.deserialize_any(V)
580            }
581
582            fn visit_unit<E>(self) -> Result<Self::Value, E>
583            where
584                E: Error,
585            {
586                Ok(None)
587            }
588
589            fn visit_none<E>(self) -> Result<Self::Value, E>
590            where
591                E: Error,
592            {
593                Ok(None)
594            }
595
596            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
597            where
598                E: Error,
599            {
600                Ok(Some(v))
601            }
602
603            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
604            where
605                E: Error,
606            {
607                v.parse().map_err(E::custom).map(Some)
608            }
609        }
610
611        d.deserialize_option(V)
612    }
613
614    #[cfg(test)]
615    mod test {
616        use serde::Deserialize;
617
618        #[test]
619        fn deserialize_ok() {
620            #[derive(Deserialize)]
621            struct Test {
622                #[serde(with = "super")]
623                value: Option<u64>,
624            }
625
626            let json = r#"{"value": null }"#;
627            let val: Test = serde_json::from_str(json).unwrap();
628            assert!(val.value.is_none());
629
630            let json = r#"{"value": "124" }"#;
631            let val: Test = serde_json::from_str(json).unwrap();
632            assert_eq!(val.value.unwrap(), 124);
633
634            let json = r#"{"value": 124 }"#;
635            let val: Test = serde_json::from_str(json).unwrap();
636            assert_eq!(val.value.unwrap(), 124);
637        }
638    }
639}
640
641mod bytes_as_base64 {
642    use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine as _};
643    use bytes::Bytes;
644    use serde::{de, ser};
645    use serde::{de::Error as _, Serialize as _};
646
647    pub fn serialize<S: ser::Serializer>(value: &Bytes, ser: S) -> Result<S::Ok, S::Error> {
648        STANDARD_NO_PAD.encode(value).serialize(ser)
649    }
650
651    pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<Bytes, D::Error> {
652        let text = <&'de str as de::Deserialize>::deserialize(de)?;
653        let text = text.trim_end_matches('=');
654        let bytes = STANDARD_NO_PAD.decode(text).map_err(|_| {
655            D::Error::invalid_value(de::Unexpected::Str(text), &"binary data encoded as base64")
656        })?;
657        Ok(Bytes::from(bytes))
658    }
659}