Skip to main content

sentinel_driver/pipeline/
mod.rs

1pub mod batch;
2
3use std::sync::Arc;
4
5use bytes::BytesMut;
6
7use crate::error::{Error, Result};
8use crate::protocol::backend::BackendMessage;
9use crate::protocol::frontend;
10use crate::row::{parse_command_tag, CommandResult, Row, RowDescription};
11
/// A single query in a pipeline together with its bound parameters.
///
/// The common std traits are derived so callers can duplicate a query
/// (for example to retry a pipeline), compare queries in tests, and start
/// from `PipelineQuery::default()`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PipelineQuery {
    /// SQL text sent in the Parse message.
    pub sql: String,
    /// Parameter type identifiers forwarded to the Parse message.
    pub param_types: Vec<u32>,
    /// Parameter values forwarded to the Bind message; each entry is either
    /// a byte payload or `None`.
    pub params: Vec<Option<Vec<u8>>>,
}
19
20/// Result of a single query within a pipeline.
21#[derive(Debug)]
22pub enum QueryResult {
23    /// Query returned rows.
24    Rows(Vec<Row>),
25    /// Query returned a command result (INSERT/UPDATE/DELETE/etc).
26    Command(CommandResult),
27}
28
29impl QueryResult {
30    /// Get rows if this is a row-returning query.
31    pub fn into_rows(self) -> Result<Vec<Row>> {
32        match self {
33            QueryResult::Rows(rows) => Ok(rows),
34            QueryResult::Command(_) => Err(Error::Protocol(
35                "expected rows but got command result".to_string(),
36            )),
37        }
38    }
39
40    /// Get command result if this is a non-row query.
41    pub fn into_command(self) -> Result<CommandResult> {
42        match self {
43            QueryResult::Command(r) => Ok(r),
44            QueryResult::Rows(_) => Err(Error::Protocol(
45                "expected command result but got rows".to_string(),
46            )),
47        }
48    }
49}
50
51/// Encode a pipeline of queries into the write buffer.
52///
53/// Each query gets: Parse (unnamed) → Bind → Describe → Execute
54/// A single Sync is appended at the end (single pipeline barrier).
55pub fn encode_pipeline(buf: &mut BytesMut, queries: &[PipelineQuery]) {
56    for q in queries {
57        // Parse with unnamed statement ("")
58        let oids: Vec<u32> = q.param_types.clone();
59        frontend::parse(buf, "", &q.sql, &oids);
60
61        // Bind with unnamed portal and statement
62        let param_refs: Vec<Option<&[u8]>> = q.params.iter().map(|p| p.as_deref()).collect();
63        frontend::bind(buf, "", "", &param_refs, &[]);
64
65        // Describe portal to get RowDescription (if SELECT)
66        frontend::describe_portal(buf, "");
67
68        // Execute with no row limit
69        frontend::execute(buf, "", 0);
70    }
71
72    // Single Sync at the end — acts as pipeline barrier
73    frontend::sync(buf);
74}
75
76/// Read pipeline responses for `count` queries.
77///
78/// Expected sequence per query:
79/// - ParseComplete
80/// - BindComplete
81/// - RowDescription (or NoData for non-SELECT)
82/// - DataRow* + CommandComplete (or just CommandComplete)
83///
84/// Finally: ReadyForQuery after the Sync.
85pub(crate) async fn read_pipeline_responses(
86    conn: &mut crate::connection::stream::PgConnection,
87    count: usize,
88) -> Result<Vec<QueryResult>> {
89    let mut results = Vec::with_capacity(count);
90
91    for _ in 0..count {
92        // ParseComplete
93        expect_message(conn, "ParseComplete", |m| {
94            matches!(m, BackendMessage::ParseComplete)
95        })
96        .await?;
97
98        // BindComplete
99        expect_message(conn, "BindComplete", |m| {
100            matches!(m, BackendMessage::BindComplete)
101        })
102        .await?;
103
104        // RowDescription or NoData
105        let msg = conn.recv().await?;
106        let description = match msg {
107            BackendMessage::RowDescription { fields } => {
108                Some(Arc::new(RowDescription::new(fields)))
109            }
110            BackendMessage::NoData => None,
111            BackendMessage::ErrorResponse { fields } => {
112                return Err(Error::server(
113                    fields.severity,
114                    fields.code,
115                    fields.message,
116                    fields.detail,
117                    fields.hint,
118                    fields.position,
119                ));
120            }
121            other => {
122                return Err(Error::protocol(format!(
123                    "expected RowDescription or NoData, got {other:?}"
124                )));
125            }
126        };
127
128        // Read DataRows + CommandComplete
129        let result = read_query_result(conn, description).await?;
130        results.push(result);
131    }
132
133    // ReadyForQuery after Sync
134    let msg = conn.recv().await?;
135    match msg {
136        BackendMessage::ReadyForQuery { .. } => {}
137        BackendMessage::ErrorResponse { fields } => {
138            return Err(Error::server(
139                fields.severity,
140                fields.code,
141                fields.message,
142                fields.detail,
143                fields.hint,
144                fields.position,
145            ));
146        }
147        other => {
148            return Err(Error::protocol(format!(
149                "expected ReadyForQuery, got {other:?}"
150            )));
151        }
152    }
153
154    Ok(results)
155}
156
157/// Read DataRows and CommandComplete for a single query in the pipeline.
158async fn read_query_result(
159    conn: &mut crate::connection::stream::PgConnection,
160    description: Option<Arc<RowDescription>>,
161) -> Result<QueryResult> {
162    let mut rows = Vec::new();
163
164    loop {
165        let msg = conn.recv().await?;
166        match msg {
167            BackendMessage::DataRow { columns } => {
168                let desc = description
169                    .as_ref()
170                    .ok_or_else(|| Error::protocol("received DataRow without RowDescription"))?;
171                rows.push(Row::new(columns, Arc::clone(desc)));
172            }
173            BackendMessage::CommandComplete { tag } => {
174                if rows.is_empty() {
175                    return Ok(QueryResult::Command(parse_command_tag(&tag)));
176                } else {
177                    return Ok(QueryResult::Rows(rows));
178                }
179            }
180            BackendMessage::EmptyQueryResponse => {
181                return Ok(QueryResult::Command(CommandResult {
182                    command: String::new(),
183                    rows_affected: 0,
184                }));
185            }
186            BackendMessage::ErrorResponse { fields } => {
187                return Err(Error::server(
188                    fields.severity,
189                    fields.code,
190                    fields.message,
191                    fields.detail,
192                    fields.hint,
193                    fields.position,
194                ));
195            }
196            other => {
197                return Err(Error::protocol(format!(
198                    "unexpected message in query result: {other:?}"
199                )));
200            }
201        }
202    }
203}
204
205async fn expect_message(
206    conn: &mut crate::connection::stream::PgConnection,
207    expected: &str,
208    check: impl FnOnce(&BackendMessage) -> bool,
209) -> Result<()> {
210    let msg = conn.recv().await?;
211    if check(&msg) {
212        Ok(())
213    } else if let BackendMessage::ErrorResponse { fields } = msg {
214        Err(Error::server(
215            fields.severity,
216            fields.code,
217            fields.message,
218            fields.detail,
219            fields.hint,
220            fields.position,
221        ))
222    } else {
223        Err(Error::protocol(format!("expected {expected}, got {msg:?}")))
224    }
225}