postgres_sync/
lib.rs

1//! A synchronous, drop-in replacement for the `postgres` crate.
2//!
3//! This crate provides a compatible API but uses standard library networking instead of `tokio`.
4//!
5//! For detailed documentation on individual functions and types, please refer to the
6//! [original `postgres` crate documentation](https://docs.rs/postgres/latest/postgres/).
7//!
8//! **Note:** `postgres_sync` implements a *subset* of the `postgres` API. If you find a
9//! feature in the `postgres` docs, it may not yet be implemented in this crate.
10
11use std::error::Error as StdError;
12use std::io::{Read, Write};
13use std::net::TcpStream;
14use std::time::Duration;
15
16use bytes::BytesMut;
17use fallible_iterator::FallibleIterator;
18use postgres_protocol::Oid;
19use postgres_protocol::authentication::{
20    md5_hash,
21    sasl::{self, ChannelBinding, ScramSha256},
22};
23use postgres_protocol::message::backend;
24use postgres_protocol::message::frontend;
25use postgres_types::{IsNull, Type};
26use socket2::{SockRef, TcpKeepalive};
27
28pub use fallible_iterator;
29pub use postgres_types::{BorrowToSql, FromSql, ToSql};
30pub use postgres_types as types;
31
32pub use crate::transaction::Transaction;
33pub use crate::config::Config;
34
35mod config;
36mod transaction;
37
/// Crate-wide error type: any boxed error that can be sent and shared across threads.
pub type Error = Box<dyn StdError + Sync + Send>;
39
/// An error reported by the server in an `ErrorResponse` message.
///
/// The `Display` output is `SEVERITY: message (code)`, followed by optional
/// `DETAIL`, `HINT` and `POSITION` lines when those fields are present.
#[derive(Debug)]
pub struct DbError {
    /// Value of the `S` field of the response.
    severity: String,
    /// Value of the `C` field of the response.
    code: String,
    /// Value of the `M` field of the response.
    message: String,
    /// Value of the `D` field, when the response carried one.
    detail: Option<String>,
    /// Value of the `H` field, when the response carried one.
    hint: Option<String>,
    /// Error position built from the `P` or `p`/`q` fields, when present.
    position: Option<ErrorPosition>,
}

impl std::fmt::Display for DbError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DbError {
            severity,
            code,
            message,
            detail,
            hint,
            position,
        } = self;

        write!(f, "{}: {} ({})", severity, message, code)?;
        if let Some(detail) = detail {
            write!(f, "\nDETAIL: {detail}")?;
        }
        if let Some(hint) = hint {
            write!(f, "\nHINT: {hint}")?;
        }
        match position {
            Some(pos) => write!(f, "\nPOSITION: {pos:?}"),
            None => Ok(()),
        }
    }
}

impl StdError for DbError {}

/// Where in a query the server located an error.
#[derive(Debug)]
pub enum ErrorPosition {
    /// Position taken from the `P` field of the response.
    Original(u32),
    /// Position taken from the `p` field, together with the `q` field's query text.
    Internal { position: u32, query: String },
}
73
/// Marker value passed to [`Client::connect`]; the connection is made over plain TCP.
#[derive(Clone, Copy, Debug)]
pub struct NoTls;
76
77pub struct Client {
78    stream: TcpStream,
79    read_buf: BytesMut,
80    write_buf: BytesMut,
81}
82
83impl Client {
84    pub fn connect(s: &str, _tls: NoTls) -> Result<Client, Error> {
85        let config = config::Config::parse(s)?;
86        Self::connect_config(&config, _tls)
87    }
88
89    fn connect_config(config: &config::Config, _tls: NoTls) -> Result<Client, Error> {
90        let stream = TcpStream::connect((config.host.as_str(), config.port))?;
91
92        let sock_ref = SockRef::from(&stream);
93        let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(50));
94        sock_ref.set_tcp_keepalive(&keepalive)?;
95
96        let user = &config.user;
97        let db = &config.db;
98
99        let mut this = Client {
100            stream,
101            read_buf: BytesMut::with_capacity(8192),
102            write_buf: BytesMut::with_capacity(8192),
103        };
104
105        let mut params: Vec<(&str, &str)> = Vec::new();
106        params.push(("user", user));
107        if !db.is_empty() {
108            params.push(("database", db));
109        }
110        params.push(("client_encoding", "UTF8"));
111
112        frontend::startup_message(params.iter().copied(), &mut this.write_buf)?;
113        this.flush()?;
114
115        this.handle_auth(user.as_bytes(), &config.password)?;
116
117        loop {
118            match this.read_message()? {
119                backend::Message::ReadyForQuery(_) => break,
120                backend::Message::BackendKeyData(_) => {}
121                backend::Message::ParameterStatus(_) => {}
122                backend::Message::ErrorResponse(body) => return Err(this.error_response(body).into()),
123                _ => return Err("unexpected message".into()),
124            }
125        }
126
127        Ok(this)
128    }
129
130    fn handle_auth(&mut self, user: &[u8], password: &str) -> Result<(), Error> {
131        loop {
132            match self.read_message()? {
133                backend::Message::AuthenticationOk => break,
134                backend::Message::AuthenticationCleartextPassword => {
135                    // TODO: untested
136                    frontend::password_message(password.as_bytes(), &mut self.write_buf)?;
137                    self.flush()?;
138                }
139                backend::Message::AuthenticationMd5Password(body) => {
140                    // TODO: untested
141                    let output = md5_hash(user, password.as_bytes(), body.salt());
142                    frontend::password_message(output.as_bytes(), &mut self.write_buf)?;
143                    self.flush()?;
144                }
145                backend::Message::AuthenticationSasl(body) => {
146                    let mut has_scram = false;
147                    let mut mechs = body.mechanisms();
148                    while let Some(mech) = mechs.next()? {
149                        if mech == sasl::SCRAM_SHA_256 {
150                            has_scram = true;
151                        }
152                    }
153                    if !has_scram {
154                        return Err("unsupported authentication".into());
155                    }
156
157                    let mut scram =
158                        ScramSha256::new(password.as_bytes(), ChannelBinding::unsupported());
159
160                    frontend::sasl_initial_response(
161                        sasl::SCRAM_SHA_256,
162                        scram.message(),
163                        &mut self.write_buf,
164                    )?;
165                    self.flush()?;
166
167                    let body = match self.read_message()? {
168                        backend::Message::AuthenticationSaslContinue(body) => body,
169                        backend::Message::ErrorResponse(body) => return Err(self.error_response(body).into()),
170                        _ => return Err("unexpected message".into()),
171                    };
172
173                    scram.update(body.data())?;
174
175                    frontend::sasl_response(scram.message(), &mut self.write_buf)?;
176                    self.flush()?;
177
178                    let body = match self.read_message()? {
179                        backend::Message::AuthenticationSaslFinal(body) => body,
180                        backend::Message::ErrorResponse(body) => return Err(self.error_response(body).into()),
181                        _ => return Err("unexpected message".into()),
182                    };
183
184                    scram.finish(body.data())?;
185                }
186                backend::Message::ErrorResponse(body) => {
187                    return Err(self.error_response(body).into());
188                }
189                _ => return Err("unsupported authentication".into()),
190            }
191        }
192        Ok(())
193    }
194
195    fn flush(&mut self) -> Result<(), Error> {
196        self.stream.write_all(&self.write_buf)?;
197        self.stream.flush()?;
198        self.write_buf.clear();
199        Ok(())
200    }
201
202    fn read_message(&mut self) -> Result<backend::Message, Error> {
203        loop {
204            if let Some(message) = backend::Message::parse(&mut self.read_buf)? {
205                return Ok(message);
206            }
207            let mut buf = [0u8; 8192];
208            let n = self.stream.read(&mut buf)?;
209            if n == 0 {
210                return Err("unexpected EOF".into());
211            }
212            self.read_buf.extend_from_slice(&buf[..n]);
213        }
214    }
215
216    fn error_response(&self, body: backend::ErrorResponseBody) -> DbError {
217        let mut severity = String::new();
218        let mut code = String::new();
219        let mut message = String::new();
220        let mut detail = None;
221        let mut hint = None;
222        let mut normal_position = None;
223        let mut internal_position = None;
224        let mut internal_query = None;
225        let mut fields = body.fields();
226        while let Some(field) = fields.next().unwrap() {
227            match field.type_() {
228                b'S' => severity = String::from_utf8_lossy(field.value_bytes()).into_owned(),
229                b'C' => code = String::from_utf8_lossy(field.value_bytes()).into_owned(),
230                b'M' => message = String::from_utf8_lossy(field.value_bytes()).into_owned(),
231                b'D' => detail = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
232                b'H' => hint = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
233                b'P' => normal_position = String::from_utf8_lossy(field.value_bytes()).parse().ok(),
234                b'p' => internal_position = String::from_utf8_lossy(field.value_bytes()).parse().ok(),
235                b'q' => internal_query = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
236                _ => {}
237            }
238        }
239        let position = match normal_position {
240            Some(pos) => Some(ErrorPosition::Original(pos)),
241            None => internal_position.map(|pos| ErrorPosition::Internal {
242                position: pos,
243                query: internal_query.unwrap_or_default(),
244            }),
245        };
246        DbError { severity, code, message, detail, hint, position }
247    }
248
249    fn drain_ready(&mut self) -> Result<(), Error> {
250        loop {
251            match self.read_message()? {
252                backend::Message::ReadyForQuery(_) => return Ok(()),
253                backend::Message::ErrorResponse(body) => {
254                    return Err(self.error_response(body).into())
255                }
256                _ => {}
257            }
258        }
259    }
260
261    #[allow(clippy::type_complexity)]
262    fn prepare_query(
263        &mut self,
264        query: &str,
265        params_len: usize,
266    ) -> Result<(Vec<Type>, Vec<(String, Oid)>), Error> {
267        let param_oids = vec![0; params_len];
268        frontend::parse("", query, param_oids.iter().copied(), &mut self.write_buf)?;
269        frontend::describe(b'S', "", &mut self.write_buf)?;
270        frontend::sync(&mut self.write_buf);
271        self.flush()?;
272
273        let mut param_types = Vec::new();
274        let mut columns = Vec::new();
275        loop {
276            match self.read_message()? {
277                backend::Message::ParseComplete => {}
278                backend::Message::ParameterDescription(body) => {
279                    let mut it = body.parameters();
280                    while let Some(oid) = it.next()? {
281                        let ty = Type::from_oid(oid).unwrap_or(Type::TEXT);
282                        param_types.push(ty);
283                    }
284                }
285                backend::Message::RowDescription(body) => {
286                    let mut fields = body.fields();
287                    while let Some(field) = fields.next()? {
288                        columns.push((field.name().to_string(), field.type_oid()));
289                    }
290                }
291                backend::Message::NoData => {}
292                backend::Message::ReadyForQuery(_) => break,
293                backend::Message::ErrorResponse(body) => {
294                    let err = self.error_response(body);
295                    self.drain_ready()?;
296                    return Err(err.into());
297                }
298                _ => return Err("unexpected message".into()),
299            }
300        }
301
302        Ok((param_types, columns))
303    }
304
305    fn bind_execute<P, I>(
306        &mut self,
307        params: I,
308        param_types: &[Type],
309        mut rows: Option<&mut Vec<Vec<Option<Vec<u8>>>>>,
310    ) -> Result<u64, Error>
311    where
312        P: BorrowToSql,
313        I: IntoIterator<Item = P>,
314        I::IntoIter: ExactSizeIterator,
315    {
316        let params: Vec<P> = params.into_iter().collect();
317        assert_eq!(param_types.len(), params.len());
318        let param_formats: Vec<i16> = params
319            .iter()
320            .zip(param_types)
321            .map(|(p, t)| p.borrow_to_sql().encode_format(t) as i16)
322            .collect();
323
324        frontend::bind(
325            "",
326            "",
327            param_formats,
328            params.iter().zip(param_types.iter()),
329            |(param, ty), buf| match param.borrow_to_sql().to_sql_checked(ty, buf)? {
330                IsNull::No => Ok(postgres_protocol::IsNull::No),
331                IsNull::Yes => Ok(postgres_protocol::IsNull::Yes),
332            },
333            Some(1),
334            &mut self.write_buf,
335        )
336        .map_err(|e| match e {
337            frontend::BindError::Conversion(e) => e,
338            frontend::BindError::Serialization(e) => Box::new(e) as Error,
339        })?;
340        frontend::execute("", 0, &mut self.write_buf)?;
341        frontend::sync(&mut self.write_buf);
342        self.flush()?;
343
344        let mut rows_affected = 0;
345        loop {
346            match self.read_message()? {
347                backend::Message::BindComplete => {}
348                backend::Message::DataRow(body) => {
349                    if let Some(out) = rows.as_mut() {
350                        out.push(self.parse_data_row(body)?);
351                    }
352                }
353                backend::Message::CommandComplete(body) => {
354                    let tag = body.tag().map_err(|e| Box::new(e) as Error)?;
355                    rows_affected = tag
356                        .rsplit(' ')
357                        .next()
358                        .and_then(|s| s.parse().ok())
359                        .unwrap_or(0);
360                }
361                backend::Message::EmptyQueryResponse => rows_affected = 0,
362                backend::Message::ReadyForQuery(_) => return Ok(rows_affected),
363                backend::Message::ErrorResponse(body) => {
364                    let err = self.error_response(body);
365                    self.drain_ready()?;
366                    return Err(err.into());
367                }
368                _ => return Err("unexpected message".into()),
369            }
370        }
371    }
372
373    pub fn query_raw<P, I>(&mut self, query: &str, params: I) -> Result<RowIter, Error>
374    where
375        P: BorrowToSql,
376        I: IntoIterator<Item = P>,
377        I::IntoIter: ExactSizeIterator,
378    {
379        let params = params.into_iter();
380        let (param_types, columns) = self.prepare_query(query, params.len())?;
381        let params: Vec<P> = params.collect();
382        let mut rows = Vec::new();
383        self.bind_execute(params, &param_types, Some(&mut rows))?;
384
385        Ok(RowIter {
386            columns,
387            rows: rows.into_iter(),
388        })
389    }
390
391    pub fn execute(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error> {
392        let (param_types, _) = self.prepare_query(query, params.len())?;
393        self.bind_execute(params.iter().copied(), &param_types, None)
394    }
395
396    pub fn query(
397        &mut self,
398        query: &str,
399        params: &[&(dyn ToSql + Sync)],
400    ) -> Result<Vec<Row>, Error> {
401        self.query_raw(query, params.iter().copied())?.collect()
402    }
403
404    pub fn query_one(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<Row, Error> {
405        let mut it = self.query_raw(query, params.iter().copied())?;
406        let first = it.next()?.ok_or("no rows returned")?;
407        if it.next()?.is_some() {
408            return Err("more than one row returned".into());
409        }
410        Ok(first)
411    }
412
413    pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
414        frontend::query(query, &mut self.write_buf)?;
415        self.flush()?;
416
417        loop {
418            match self.read_message()? {
419                backend::Message::ReadyForQuery(_) => return Ok(()),
420                backend::Message::CommandComplete(_)
421                | backend::Message::EmptyQueryResponse
422                | backend::Message::RowDescription(_)
423                | backend::Message::DataRow(_) => {}
424                backend::Message::ErrorResponse(body) => {
425                    let err = self.error_response(body);
426                    self.drain_ready()?;
427                    return Err(err.into());
428                }
429                _ => return Err("unexpected message".into()),
430            }
431        }
432    }
433
434    fn parse_data_row(&self, body: backend::DataRowBody) -> Result<Vec<Option<Vec<u8>>>, Error> {
435        let mut out = Vec::new();
436        let mut ranges = body.ranges();
437        let buf = body.buffer();
438        while let Some(range) = ranges.next()? {
439            match range {
440                Some(r) => out.push(Some(buf[r].to_vec())),
441                None => out.push(None),
442            }
443        }
444        Ok(out)
445    }
446}
447
448pub struct Row {
449    columns: Vec<(String, Oid)>,
450    values: Vec<Option<Vec<u8>>>,
451}
452
453pub trait RowIndex {
454    fn idx(&self, columns: &[(String, Oid)]) -> Option<usize>;
455}
456
457impl RowIndex for usize {
458    fn idx(&self, columns: &[(String, Oid)]) -> Option<usize> {
459        if *self < columns.len() { Some(*self) } else { None }
460    }
461}
462
463impl RowIndex for &str {
464    fn idx(&self, columns: &[(String, Oid)]) -> Option<usize> {
465        columns.iter()
466            .position(|(name, _)| name == self)
467        .or_else(|| columns.iter()
468            .position(|(name, _)| name.eq_ignore_ascii_case(self)))
469    }
470}
471
472
473impl Row {
474    pub fn get<'a, T>(&'a self, idx: impl RowIndex) -> T
475    where
476        T: FromSql<'a>,
477    {
478        let idx = idx.idx(&self.columns).expect("invalid column");
479        let (_, oid) = &self.columns[idx];
480        let ty = Type::from_oid(*oid).unwrap_or(Type::TEXT);
481        let raw = self.values[idx].as_deref();
482        FromSql::from_sql_nullable(&ty, raw).unwrap()
483    }
484}
485
486pub struct RowIter {
487    columns: Vec<(String, Oid)>,
488    rows: std::vec::IntoIter<Vec<Option<Vec<u8>>>>,
489}
490
491impl FallibleIterator for RowIter {
492    type Item = Row;
493    type Error = Error;
494
495    fn next(&mut self) -> Result<Option<Row>, Error> {
496        Ok(self.rows.next().map(|values| Row {
497            columns: self.columns.clone(),
498            values,
499        }))
500    }
501}
502