1use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6
7#[derive(Serialize, Deserialize, prost::Message)]
8pub struct PipelineReqBody {
9 #[prost(string, optional, tag = "1")]
10 pub baton: Option<String>,
11 #[prost(message, repeated, tag = "2")]
12 pub requests: Vec<StreamRequest>,
13}
14
15#[derive(Serialize, Deserialize, prost::Message)]
16pub struct PipelineRespBody {
17 #[prost(string, optional, tag = "1")]
18 pub baton: Option<String>,
19 #[prost(string, optional, tag = "2")]
20 pub base_url: Option<String>,
21 #[prost(message, repeated, tag = "3")]
22 pub results: Vec<StreamResult>,
23}
24
25#[derive(Serialize, Deserialize, Default, Debug)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum StreamResult {
28 #[default]
29 None,
30 Ok {
31 response: StreamResponse,
32 },
33 Error {
34 error: Error,
35 },
36}
37
38#[derive(Serialize, Deserialize, prost::Message)]
39pub struct CursorReqBody {
40 #[prost(string, optional, tag = "1")]
41 pub baton: Option<String>,
42 #[prost(message, required, tag = "2")]
43 pub batch: Batch,
44}
45
46#[derive(Serialize, Deserialize, prost::Message)]
47pub struct CursorRespBody {
48 #[prost(string, optional, tag = "1")]
49 pub baton: Option<String>,
50 #[prost(string, optional, tag = "2")]
51 pub base_url: Option<String>,
52}
53
54#[derive(Serialize, Deserialize, Debug, Default)]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum StreamRequest {
57 #[serde(skip_deserializing)]
58 #[default]
59 None,
60 Close(CloseStreamReq),
61 Execute(ExecuteStreamReq),
62 Batch(BatchStreamReq),
63 Sequence(SequenceStreamReq),
64 Describe(DescribeStreamReq),
65 StoreSql(StoreSqlStreamReq),
66 CloseSql(CloseSqlStreamReq),
67 GetAutocommit(GetAutocommitStreamReq),
68}
69
70#[derive(Serialize, Deserialize, Debug)]
71#[serde(tag = "type", rename_all = "snake_case")]
72pub enum StreamResponse {
73 Close(CloseStreamResp),
74 Execute(ExecuteStreamResp),
75 Batch(BatchStreamResp),
76 Sequence(SequenceStreamResp),
77 Describe(DescribeStreamResp),
78 StoreSql(StoreSqlStreamResp),
79 CloseSql(CloseSqlStreamResp),
80 GetAutocommit(GetAutocommitStreamResp),
81}
82
83#[derive(Serialize, Deserialize, prost::Message)]
84pub struct CloseStreamReq {}
85
86#[derive(Serialize, Deserialize, prost::Message)]
87pub struct CloseStreamResp {}
88
89#[derive(Serialize, Deserialize, prost::Message)]
90pub struct ExecuteStreamReq {
91 #[prost(message, required, tag = "1")]
92 pub stmt: Stmt,
93}
94
95#[derive(Serialize, Deserialize, prost::Message)]
96pub struct ExecuteStreamResp {
97 #[prost(message, required, tag = "1")]
98 pub result: StmtResult,
99}
100
101#[derive(Serialize, Deserialize, prost::Message)]
102pub struct BatchStreamReq {
103 #[prost(message, required, tag = "1")]
104 pub batch: Batch,
105}
106
107#[derive(Serialize, Deserialize, prost::Message)]
108pub struct BatchStreamResp {
109 #[prost(message, required, tag = "1")]
110 pub result: BatchResult,
111}
112
113#[derive(Serialize, Deserialize, prost::Message)]
114pub struct SequenceStreamReq {
115 #[serde(default)]
116 #[prost(string, optional, tag = "1")]
117 pub sql: Option<String>,
118 #[serde(default)]
119 #[prost(int32, optional, tag = "2")]
120 pub sql_id: Option<i32>,
121 #[serde(default, with = "option_u64_as_str")]
122 #[prost(uint64, optional, tag = "3")]
123 pub replication_index: Option<u64>,
124}
125
126#[derive(Serialize, Deserialize, prost::Message)]
127pub struct SequenceStreamResp {}
128
129#[derive(Serialize, Deserialize, prost::Message)]
130pub struct DescribeStreamReq {
131 #[serde(default)]
132 #[prost(string, optional, tag = "1")]
133 pub sql: Option<String>,
134 #[serde(default)]
135 #[prost(int32, optional, tag = "2")]
136 pub sql_id: Option<i32>,
137 #[serde(default, with = "option_u64_as_str")]
138 #[prost(uint64, optional, tag = "3")]
139 pub replication_index: Option<u64>,
140}
141
142#[derive(Serialize, Deserialize, prost::Message)]
143pub struct DescribeStreamResp {
144 #[prost(message, required, tag = "1")]
145 pub result: DescribeResult,
146}
147
148#[derive(Serialize, Deserialize, prost::Message)]
149pub struct StoreSqlStreamReq {
150 #[prost(int32, tag = "1")]
151 pub sql_id: i32,
152 #[prost(string, tag = "2")]
153 pub sql: String,
154}
155
156#[derive(Serialize, Deserialize, prost::Message)]
157pub struct StoreSqlStreamResp {}
158
159#[derive(Serialize, Deserialize, prost::Message)]
160pub struct CloseSqlStreamReq {
161 #[prost(int32, tag = "1")]
162 pub sql_id: i32,
163}
164
165#[derive(Serialize, Deserialize, prost::Message)]
166pub struct CloseSqlStreamResp {}
167
168#[derive(Serialize, Deserialize, prost::Message)]
169pub struct GetAutocommitStreamReq {}
170
171#[derive(Serialize, Deserialize, prost::Message)]
172pub struct GetAutocommitStreamResp {
173 #[prost(bool, tag = "1")]
174 pub is_autocommit: bool,
175}
176
177#[derive(Clone, Deserialize, Serialize, prost::Message)]
178pub struct Error {
179 #[prost(string, tag = "1")]
180 pub message: String,
181 #[prost(string, tag = "2")]
182 pub code: String,
183}
184
185#[derive(Clone, Deserialize, Serialize, prost::Message)]
186pub struct Stmt {
187 #[serde(default)]
188 #[prost(string, optional, tag = "1")]
189 pub sql: Option<String>,
190 #[serde(default)]
191 #[prost(int32, optional, tag = "2")]
192 pub sql_id: Option<i32>,
193 #[serde(default)]
194 #[prost(message, repeated, tag = "3")]
195 pub args: Vec<Value>,
196 #[serde(default)]
197 #[prost(message, repeated, tag = "4")]
198 pub named_args: Vec<NamedArg>,
199 #[serde(default)]
200 #[prost(bool, optional, tag = "5")]
201 pub want_rows: Option<bool>,
202 #[serde(default, with = "option_u64_as_str")]
203 #[prost(uint64, optional, tag = "6")]
204 pub replication_index: Option<u64>,
205}
206
207impl Stmt {
208 pub fn new<S: Into<String>>(sql: S, want_rows: bool) -> Self {
209 Stmt {
210 sql: Some(sql.into()),
211 sql_id: None,
212 args: vec![],
213 named_args: vec![],
214 want_rows: Some(want_rows),
215 replication_index: None,
216 }
217 }
218
219 pub fn bind(&mut self, value: Value) {
220 self.args.push(value);
221 }
222
223 pub fn bind_named(&mut self, name: String, value: Value) {
224 self.named_args.push(NamedArg { name, value });
225 }
226}
227
228#[derive(Clone, Deserialize, Serialize, prost::Message)]
229pub struct NamedArg {
230 #[prost(string, tag = "1")]
231 pub name: String,
232 #[prost(message, required, tag = "2")]
233 pub value: Value,
234}
235
236#[derive(Clone, Serialize, Deserialize, prost::Message)]
237pub struct StmtResult {
238 #[prost(message, repeated, tag = "1")]
239 pub cols: Vec<Col>,
240 #[prost(message, repeated, tag = "2")]
241 pub rows: Vec<Row>,
242 #[prost(uint64, tag = "3")]
243 pub affected_row_count: u64,
244 #[serde(with = "option_i64_as_str")]
245 #[prost(sint64, optional, tag = "4")]
246 pub last_insert_rowid: Option<i64>,
247 #[serde(default, with = "option_u64_as_str")]
248 #[prost(uint64, optional, tag = "5")]
249 pub replication_index: Option<u64>,
250 #[prost(uint64, tag = "6")]
251 #[serde(default)]
252 pub rows_read: u64,
253 #[prost(uint64, tag = "7")]
254 #[serde(default)]
255 pub rows_written: u64,
256 #[prost(double, tag = "8")]
257 #[serde(default)]
258 pub query_duration_ms: f64,
259}
260
261#[derive(Clone, Deserialize, Serialize, prost::Message)]
262pub struct Col {
263 #[prost(string, optional, tag = "1")]
264 pub name: Option<String>,
265 #[prost(string, optional, tag = "2")]
266 pub decltype: Option<String>,
267}
268
269#[derive(Clone, Deserialize, Serialize, prost::Message)]
270#[serde(transparent)]
271pub struct Row {
272 #[prost(message, repeated, tag = "1")]
273 pub values: Vec<Value>,
274}
275
276#[derive(Clone, Deserialize, Serialize, prost::Message)]
277pub struct Batch {
278 #[prost(message, repeated, tag = "1")]
279 pub steps: Vec<BatchStep>,
280 #[prost(uint64, optional, tag = "2")]
281 #[serde(default, with = "option_u64_as_str")]
282 pub replication_index: Option<u64>,
283}
284
285impl Batch {
286 pub fn single(stmt: Stmt) -> Self {
287 Batch {
288 steps: vec![BatchStep {
289 condition: None,
290 stmt,
291 }],
292 replication_index: None,
293 }
294 }
295 pub fn transactional<T: IntoIterator<Item = Stmt>>(stmts: T) -> Self {
296 let mut steps = Vec::new();
297 steps.push(BatchStep {
298 condition: None,
299 stmt: Stmt::new("BEGIN TRANSACTION", false),
300 });
301 let mut count = 0u32;
302 for (step, stmt) in stmts.into_iter().enumerate() {
303 count += 1;
304 let condition = Some(BatchCond::Ok { step: step as u32 });
305 steps.push(BatchStep { condition, stmt });
306 }
307 steps.push(BatchStep {
308 condition: Some(BatchCond::Ok { step: count }),
309 stmt: Stmt::new("COMMIT", false),
310 });
311 steps.push(BatchStep {
312 condition: Some(BatchCond::Not {
313 cond: Box::new(BatchCond::Ok { step: count + 1 }),
314 }),
315 stmt: Stmt::new("ROLLBACK", false),
316 });
317 Batch {
318 steps,
319 replication_index: None,
320 }
321 }
322}
323
324impl FromIterator<Stmt> for Batch {
325 fn from_iter<T: IntoIterator<Item = Stmt>>(stmts: T) -> Self {
326 let mut steps = Vec::new();
327 for (step, stmt) in stmts.into_iter().enumerate() {
328 let condition = if step > 0 {
329 Some(BatchCond::Ok {
330 step: (step - 1) as u32,
331 })
332 } else {
333 None
334 };
335 steps.push(BatchStep { condition, stmt });
336 }
337 Batch {
338 steps,
339 replication_index: None,
340 }
341 }
342}
343
344#[derive(Clone, Deserialize, Serialize, prost::Message)]
345pub struct BatchStep {
346 #[serde(default)]
347 #[prost(message, optional, tag = "1")]
348 pub condition: Option<BatchCond>,
349 #[prost(message, required, tag = "2")]
350 pub stmt: Stmt,
351}
352
353#[derive(Clone, Deserialize, Serialize, Debug, Default)]
354pub struct BatchResult {
355 pub step_results: Vec<Option<StmtResult>>,
356 pub step_errors: Vec<Option<Error>>,
357 #[serde(default, with = "option_u64_as_str")]
358 pub replication_index: Option<u64>,
359}
360
361#[derive(Clone, Deserialize, Serialize, Debug, Default)]
362#[serde(tag = "type", rename_all = "snake_case")]
363pub enum BatchCond {
364 #[serde(skip_deserializing)]
365 #[default]
366 None,
367 Ok {
368 step: u32,
369 },
370 Error {
371 step: u32,
372 },
373 Not {
374 cond: Box<BatchCond>,
375 },
376 And(BatchCondList),
377 Or(BatchCondList),
378 IsAutocommit {},
379}
380
381#[derive(Clone, Deserialize, Serialize, prost::Message)]
382pub struct BatchCondList {
383 #[prost(message, repeated, tag = "1")]
384 pub conds: Vec<BatchCond>,
385}
386
387#[derive(Clone, Deserialize, Serialize, Debug, Default)]
388#[serde(tag = "type", rename_all = "snake_case")]
389pub enum CursorEntry {
390 #[serde(skip_deserializing)]
391 #[default]
392 None,
393 StepBegin(StepBeginEntry),
394 StepEnd(StepEndEntry),
395 StepError(StepErrorEntry),
396 Row {
397 row: Row,
398 },
399 Error {
400 error: Error,
401 },
402 ReplicationIndex {
403 replication_index: Option<u64>,
404 },
405}
406
407#[derive(Clone, Deserialize, Serialize, prost::Message)]
408pub struct StepBeginEntry {
409 #[prost(uint32, tag = "1")]
410 pub step: u32,
411 #[prost(message, repeated, tag = "2")]
412 pub cols: Vec<Col>,
413}
414
415#[derive(Clone, Deserialize, Serialize, prost::Message)]
416pub struct StepEndEntry {
417 #[prost(uint64, tag = "1")]
418 pub affected_row_count: u64,
419 #[prost(sint64, optional, tag = "2")]
420 pub last_insert_rowid: Option<i64>,
421}
422
423#[derive(Clone, Deserialize, Serialize, prost::Message)]
424pub struct StepErrorEntry {
425 #[prost(uint32, tag = "1")]
426 pub step: u32,
427 #[prost(message, required, tag = "2")]
428 pub error: Error,
429}
430
431#[derive(Clone, Deserialize, Serialize, prost::Message)]
432pub struct DescribeResult {
433 #[prost(message, repeated, tag = "1")]
434 pub params: Vec<DescribeParam>,
435 #[prost(message, repeated, tag = "2")]
436 pub cols: Vec<DescribeCol>,
437 #[prost(bool, tag = "3")]
438 pub is_explain: bool,
439 #[prost(bool, tag = "4")]
440 pub is_readonly: bool,
441}
442
443#[derive(Clone, Deserialize, Serialize, prost::Message)]
444pub struct DescribeParam {
445 #[prost(string, optional, tag = "1")]
446 pub name: Option<String>,
447}
448
449#[derive(Clone, Deserialize, Serialize, prost::Message)]
450pub struct DescribeCol {
451 #[prost(string, tag = "1")]
452 pub name: String,
453 #[prost(string, optional, tag = "2")]
454 pub decltype: Option<String>,
455}
456
457#[derive(Debug, Clone, Serialize, Deserialize, Default)]
458#[serde(tag = "type", rename_all = "snake_case")]
459pub enum Value {
460 #[serde(skip_deserializing)]
461 #[default]
462 None,
463 Null,
464 Integer {
465 #[serde(with = "i64_as_str")]
466 value: i64,
467 },
468 Float {
469 value: f64,
470 },
471 Text {
472 value: Arc<str>,
473 },
474 Blob {
475 #[serde(with = "bytes_as_base64", rename = "base64")]
476 value: Bytes,
477 },
478}
479
480mod i64_as_str {
481 use serde::{de, ser};
482 use serde::{de::Error as _, Serialize as _};
483
484 pub fn serialize<S: ser::Serializer>(value: &i64, ser: S) -> Result<S::Ok, S::Error> {
485 value.to_string().serialize(ser)
486 }
487
488 pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<i64, D::Error> {
489 let str_value = <&'de str as de::Deserialize>::deserialize(de)?;
490 str_value.parse().map_err(|_| {
491 D::Error::invalid_value(
492 de::Unexpected::Str(str_value),
493 &"decimal integer as a string",
494 )
495 })
496 }
497}
498
499mod option_i64_as_str {
500 use serde::de::{Error, Visitor};
501 use serde::{ser, Deserializer, Serialize as _};
502
503 pub fn serialize<S: ser::Serializer>(value: &Option<i64>, ser: S) -> Result<S::Ok, S::Error> {
504 value.map(|v| v.to_string()).serialize(ser)
505 }
506
507 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<i64>, D::Error> {
508 struct V;
509
510 impl<'de> Visitor<'de> for V {
511 type Value = Option<i64>;
512
513 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
514 write!(formatter, "a string representing a signed integer, or null")
515 }
516
517 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
518 where
519 D: Deserializer<'de>,
520 {
521 deserializer.deserialize_any(V)
522 }
523
524 fn visit_none<E>(self) -> Result<Self::Value, E>
525 where
526 E: Error,
527 {
528 Ok(None)
529 }
530
531 fn visit_unit<E>(self) -> Result<Self::Value, E>
532 where
533 E: Error,
534 {
535 Ok(None)
536 }
537
538 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
539 where
540 E: Error,
541 {
542 Ok(Some(v))
543 }
544
545 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
546 where
547 E: Error,
548 {
549 v.parse().map_err(E::custom).map(Some)
550 }
551 }
552
553 d.deserialize_option(V)
554 }
555}
556
557pub mod option_u64_as_str {
558 use serde::de::Error;
559 use serde::{de::Visitor, ser, Deserializer, Serialize as _};
560
561 pub fn serialize<S: ser::Serializer>(value: &Option<u64>, ser: S) -> Result<S::Ok, S::Error> {
562 value.map(|v| v.to_string()).serialize(ser)
563 }
564
565 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<u64>, D::Error> {
566 struct V;
567
568 impl<'de> Visitor<'de> for V {
569 type Value = Option<u64>;
570
571 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
572 write!(formatter, "a string representing an integer, or null")
573 }
574
575 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
576 where
577 D: Deserializer<'de>,
578 {
579 deserializer.deserialize_any(V)
580 }
581
582 fn visit_unit<E>(self) -> Result<Self::Value, E>
583 where
584 E: Error,
585 {
586 Ok(None)
587 }
588
589 fn visit_none<E>(self) -> Result<Self::Value, E>
590 where
591 E: Error,
592 {
593 Ok(None)
594 }
595
596 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
597 where
598 E: Error,
599 {
600 Ok(Some(v))
601 }
602
603 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
604 where
605 E: Error,
606 {
607 v.parse().map_err(E::custom).map(Some)
608 }
609 }
610
611 d.deserialize_option(V)
612 }
613
614 #[cfg(test)]
615 mod test {
616 use serde::Deserialize;
617
618 #[test]
619 fn deserialize_ok() {
620 #[derive(Deserialize)]
621 struct Test {
622 #[serde(with = "super")]
623 value: Option<u64>,
624 }
625
626 let json = r#"{"value": null }"#;
627 let val: Test = serde_json::from_str(json).unwrap();
628 assert!(val.value.is_none());
629
630 let json = r#"{"value": "124" }"#;
631 let val: Test = serde_json::from_str(json).unwrap();
632 assert_eq!(val.value.unwrap(), 124);
633
634 let json = r#"{"value": 124 }"#;
635 let val: Test = serde_json::from_str(json).unwrap();
636 assert_eq!(val.value.unwrap(), 124);
637 }
638 }
639}
640
641mod bytes_as_base64 {
642 use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine as _};
643 use bytes::Bytes;
644 use serde::{de, ser};
645 use serde::{de::Error as _, Serialize as _};
646
647 pub fn serialize<S: ser::Serializer>(value: &Bytes, ser: S) -> Result<S::Ok, S::Error> {
648 STANDARD_NO_PAD.encode(value).serialize(ser)
649 }
650
651 pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<Bytes, D::Error> {
652 let text = <&'de str as de::Deserialize>::deserialize(de)?;
653 let text = text.trim_end_matches('=');
654 let bytes = STANDARD_NO_PAD.decode(text).map_err(|_| {
655 D::Error::invalid_value(de::Unexpected::Str(text), &"binary data encoded as base64")
656 })?;
657 Ok(Bytes::from(bytes))
658 }
659}