postgres_sync/
lib.rs

1//! A synchronous, drop-in replacement for the `postgres` crate.
2//!
3//! This crate provides a compatible API but uses standard library networking instead of `tokio`.
4//!
5//! For detailed documentation on individual functions and types, please refer to the
6//! [original `postgres` crate documentation](https://docs.rs/postgres/latest/postgres/).
7//!
8//! **Note:** `postgres_sync` implements a *subset* of the `postgres` API. If you find a
9//! feature in the `postgres` docs, it may not yet be implemented in this crate.
10
11use std::error::Error as StdError;
12use std::io::{Read, Write};
13use std::net::TcpStream;
14use std::time::Duration;
15
16use bytes::BytesMut;
17use fallible_iterator::FallibleIterator;
18use postgres_protocol::Oid;
19use postgres_protocol::authentication::{
20    md5_hash,
21    sasl::{self, ChannelBinding, ScramSha256},
22};
23use postgres_protocol::message::backend;
24use postgres_protocol::message::frontend;
25use postgres_types::{IsNull, Type};
26use socket2::{SockRef, TcpKeepalive};
27
28pub use fallible_iterator;
29pub use postgres_types::{BorrowToSql, FromSql, ToSql};
30pub use postgres_types as types;
31
32pub use crate::transaction::Transaction;
33pub use crate::config::Config;
34
35mod config;
36mod transaction;
37
/// The boxed error type used throughout this crate.
///
/// Errors reported by the server are boxed [`DbError`] values; use
/// `err.downcast_ref::<DbError>()` to inspect them.
pub type Error = Box<dyn std::error::Error + Sync + Send + 'static>;
39
40#[derive(Debug)]
41pub struct DbError {
42    severity: String,
43    code: String,
44    message: String,
45    detail: Option<String>,
46    hint: Option<String>,
47    position: Option<ErrorPosition>,
48}
49
50impl std::fmt::Display for DbError {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        write!(f, "{}: {} ({})", self.severity, self.message, self.code)?;
53        if let Some(detail) = &self.detail {
54            write!(f, "\nDETAIL: {detail}")?;
55        }
56        if let Some(hint) = &self.hint {
57            write!(f, "\nHINT: {hint}")?;
58        }
59        if let Some(pos) = &self.position {
60            write!(f, "\nPOSITION: {pos:?}")?;
61        }
62        Ok(())
63    }
64}
65
66impl StdError for DbError {}
67
68impl DbError {
69    fn parse(mut fields: backend::ErrorFields<'_>) -> Self {
70    let mut severity = String::new();
71    let mut code = String::new();
72    let mut message = String::new();
73    let mut detail = None;
74    let mut hint = None;
75    let mut normal_position = None;
76    let mut internal_position = None;
77    let mut internal_query = None;
78    while let Some(field) = fields.next().unwrap() {
79        match field.type_() {
80            b'S' => severity = String::from_utf8_lossy(field.value_bytes()).into_owned(),
81            b'C' => code = String::from_utf8_lossy(field.value_bytes()).into_owned(),
82            b'M' => message = String::from_utf8_lossy(field.value_bytes()).into_owned(),
83            b'D' => detail = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
84            b'H' => hint = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
85            b'P' => normal_position = String::from_utf8_lossy(field.value_bytes()).parse().ok(),
86            b'p' => internal_position = String::from_utf8_lossy(field.value_bytes()).parse().ok(),
87            b'q' => internal_query = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
88            _ => {}
89        }
90    }
91    let position = match normal_position {
92        Some(pos) => Some(ErrorPosition::Original(pos)),
93        None => internal_position.map(|pos| ErrorPosition::Internal {
94            position: pos,
95            query: internal_query.unwrap_or_default(),
96        }),
97    };
98    Self { severity, code, message, detail, hint, position }
99    }
100}
101
/// The location in a query at which the server detected an error.
#[derive(Debug)]
pub enum ErrorPosition {
    /// A 1-based character offset into the query submitted by the client.
    Original(u32),
    /// A position inside a query generated internally by the server rather than
    /// the one the client submitted.
    Internal {
        /// 1-based character offset into `query`.
        position: u32,
        /// Text of the internally generated query (empty if the server omitted it).
        query: String,
    },
}
107
/// Marker for an unencrypted connection.
///
/// This crate always connects over plain TCP; `NoTls` exists so that call sites written
/// for the `postgres` crate, such as `Client::connect(url, NoTls)`, keep compiling.
#[derive(Clone, Copy, Debug)]
pub struct NoTls;
110
111pub struct Client {
112    stream: TcpStream,
113    read_buf: BytesMut,
114    write_buf: BytesMut,
115}
116
117impl Client {
118    pub fn connect(s: &str, _tls: NoTls) -> Result<Client, Error> {
119        let config = config::Config::parse(s)?;
120        Self::connect_config(&config, _tls)
121    }
122
123    fn connect_config(config: &config::Config, _tls: NoTls) -> Result<Client, Error> {
124        let stream = TcpStream::connect((config.host.as_str(), config.port))?;
125
126        let sock_ref = SockRef::from(&stream);
127        let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(50));
128        sock_ref.set_tcp_keepalive(&keepalive)?;
129
130        let user = &config.user;
131        let db = &config.db;
132
133        let mut this = Client {
134            stream,
135            read_buf: BytesMut::with_capacity(8192),
136            write_buf: BytesMut::with_capacity(8192),
137        };
138
139        let mut params: Vec<(&str, &str)> = Vec::new();
140        params.push(("user", user));
141        if !db.is_empty() {
142            params.push(("database", db));
143        }
144        params.push(("client_encoding", "UTF8"));
145
146        frontend::startup_message(params.iter().copied(), &mut this.write_buf)?;
147        this.flush()?;
148
149        this.handle_auth(user.as_bytes(), &config.password)?;
150
151        loop {
152            match this.read_message()? {
153                backend::Message::ReadyForQuery(_) => break,
154                backend::Message::BackendKeyData(_) => {}
155                backend::Message::ParameterStatus(_) => {}
156                backend::Message::ErrorResponse(body) => return Err(DbError::parse(body.fields()).into()),
157                _ => return Err("unexpected message".into()),
158            }
159        }
160
161        Ok(this)
162    }
163
164    fn handle_auth(&mut self, user: &[u8], password: &str) -> Result<(), Error> {
165        loop {
166            match self.read_message()? {
167                backend::Message::AuthenticationOk => break,
168                backend::Message::AuthenticationCleartextPassword => {
169                    // TODO: untested
170                    frontend::password_message(password.as_bytes(), &mut self.write_buf)?;
171                    self.flush()?;
172                }
173                backend::Message::AuthenticationMd5Password(body) => {
174                    // TODO: untested
175                    let output = md5_hash(user, password.as_bytes(), body.salt());
176                    frontend::password_message(output.as_bytes(), &mut self.write_buf)?;
177                    self.flush()?;
178                }
179                backend::Message::AuthenticationSasl(body) => {
180                    let mut has_scram = false;
181                    let mut mechs = body.mechanisms();
182                    while let Some(mech) = mechs.next()? {
183                        if mech == sasl::SCRAM_SHA_256 {
184                            has_scram = true;
185                        }
186                    }
187                    if !has_scram {
188                        return Err("unsupported authentication".into());
189                    }
190
191                    let mut scram =
192                        ScramSha256::new(password.as_bytes(), ChannelBinding::unsupported());
193
194                    frontend::sasl_initial_response(
195                        sasl::SCRAM_SHA_256,
196                        scram.message(),
197                        &mut self.write_buf,
198                    )?;
199                    self.flush()?;
200
201                    let body = match self.read_message()? {
202                        backend::Message::AuthenticationSaslContinue(body) => body,
203                        backend::Message::ErrorResponse(body) => return Err(DbError::parse(body.fields()).into()),
204                        _ => return Err("unexpected message".into()),
205                    };
206
207                    scram.update(body.data())?;
208
209                    frontend::sasl_response(scram.message(), &mut self.write_buf)?;
210                    self.flush()?;
211
212                    let body = match self.read_message()? {
213                        backend::Message::AuthenticationSaslFinal(body) => body,
214                        backend::Message::ErrorResponse(body) => return Err(DbError::parse(body.fields()).into()),
215                        _ => return Err("unexpected message".into()),
216                    };
217
218                    scram.finish(body.data())?;
219                }
220                backend::Message::ErrorResponse(body) => {
221                    return Err(DbError::parse(body.fields()).into());
222                }
223                _ => return Err("unsupported authentication".into()),
224            }
225        }
226        Ok(())
227    }
228
229    fn flush(&mut self) -> Result<(), Error> {
230        self.stream.write_all(&self.write_buf)?;
231        self.stream.flush()?;
232        self.write_buf.clear();
233        Ok(())
234    }
235
236    fn read_message(&mut self) -> Result<backend::Message, Error> {
237        loop {
238            if let Some(message) = backend::Message::parse(&mut self.read_buf)? {
239                if let backend::Message::NoticeResponse(body) = &message {
240                    log::info!("postgres notice: {:?}", DbError::parse(body.fields()));
241                    continue;
242                }
243                return Ok(message);
244            }
245            let mut buf = [0u8; 8192];
246            let n = self.stream.read(&mut buf)?;
247            if n == 0 {
248                return Err("unexpected EOF".into());
249            }
250            self.read_buf.extend_from_slice(&buf[..n]);
251        }
252    }
253
254    fn drain_ready(&mut self) -> Result<(), Error> {
255        loop {
256            match self.read_message()? {
257                backend::Message::ReadyForQuery(_) => return Ok(()),
258                backend::Message::ErrorResponse(body) => {
259                    return Err(DbError::parse(body.fields()).into())
260                }
261                _ => {}
262            }
263        }
264    }
265
266    #[allow(clippy::type_complexity)]
267    fn prepare_query(
268        &mut self,
269        query: &str,
270        params_len: usize,
271    ) -> Result<(Vec<Type>, Vec<(String, Oid)>), Error> {
272        let param_oids = vec![0; params_len];
273        frontend::parse("", query, param_oids.iter().copied(), &mut self.write_buf)?;
274        frontend::describe(b'S', "", &mut self.write_buf)?;
275        frontend::sync(&mut self.write_buf);
276        self.flush()?;
277
278        let mut param_types = Vec::new();
279        let mut columns = Vec::new();
280        loop {
281            match self.read_message()? {
282                backend::Message::ParseComplete => {}
283                backend::Message::ParameterDescription(body) => {
284                    let mut it = body.parameters();
285                    while let Some(oid) = it.next()? {
286                        let ty = Type::from_oid(oid).unwrap_or(Type::TEXT);
287                        param_types.push(ty);
288                    }
289                }
290                backend::Message::RowDescription(body) => {
291                    let mut fields = body.fields();
292                    while let Some(field) = fields.next()? {
293                        columns.push((field.name().to_string(), field.type_oid()));
294                    }
295                }
296                backend::Message::NoData => {}
297                backend::Message::ReadyForQuery(_) => break,
298                backend::Message::ErrorResponse(body) => {
299                    let err = DbError::parse(body.fields());
300                    self.drain_ready()?;
301                    return Err(err.into());
302                }
303                _ => return Err("unexpected message".into()),
304            }
305        }
306
307        Ok((param_types, columns))
308    }
309
310    fn bind_execute<P, I>(
311        &mut self,
312        params: I,
313        param_types: &[Type],
314        mut rows: Option<&mut Vec<Vec<Option<Vec<u8>>>>>,
315    ) -> Result<u64, Error>
316    where
317        P: BorrowToSql,
318        I: IntoIterator<Item = P>,
319        I::IntoIter: ExactSizeIterator,
320    {
321        let params: Vec<P> = params.into_iter().collect();
322        assert_eq!(param_types.len(), params.len());
323        let param_formats: Vec<i16> = params
324            .iter()
325            .zip(param_types)
326            .map(|(p, t)| p.borrow_to_sql().encode_format(t) as i16)
327            .collect();
328
329        frontend::bind(
330            "",
331            "",
332            param_formats,
333            params.iter().zip(param_types.iter()),
334            |(param, ty), buf| match param.borrow_to_sql().to_sql_checked(ty, buf)? {
335                IsNull::No => Ok(postgres_protocol::IsNull::No),
336                IsNull::Yes => Ok(postgres_protocol::IsNull::Yes),
337            },
338            Some(1),
339            &mut self.write_buf,
340        )
341        .map_err(|e| match e {
342            frontend::BindError::Conversion(e) => e,
343            frontend::BindError::Serialization(e) => Box::new(e) as Error,
344        })?;
345        frontend::execute("", 0, &mut self.write_buf)?;
346        frontend::sync(&mut self.write_buf);
347        self.flush()?;
348
349        let mut rows_affected = 0;
350        loop {
351            match self.read_message()? {
352                backend::Message::BindComplete => {}
353                backend::Message::DataRow(body) => {
354                    if let Some(out) = rows.as_mut() {
355                        out.push(self.parse_data_row(body)?);
356                    }
357                }
358                backend::Message::CommandComplete(body) => {
359                    let tag = body.tag().map_err(|e| Box::new(e) as Error)?;
360                    rows_affected = tag
361                        .rsplit(' ')
362                        .next()
363                        .and_then(|s| s.parse().ok())
364                        .unwrap_or(0);
365                }
366                backend::Message::EmptyQueryResponse => rows_affected = 0,
367                backend::Message::ReadyForQuery(_) => return Ok(rows_affected),
368                backend::Message::ErrorResponse(body) => {
369                    let err = DbError::parse(body.fields());
370                    self.drain_ready()?;
371                    return Err(err.into());
372                }
373                _ => return Err("unexpected message".into()),
374            }
375        }
376    }
377
378    pub fn query_raw<P, I>(&mut self, query: &str, params: I) -> Result<RowIter, Error>
379    where
380        P: BorrowToSql,
381        I: IntoIterator<Item = P>,
382        I::IntoIter: ExactSizeIterator,
383    {
384        let params = params.into_iter();
385        let (param_types, columns) = self.prepare_query(query, params.len())?;
386        let params: Vec<P> = params.collect();
387        let mut rows = Vec::new();
388        self.bind_execute(params, &param_types, Some(&mut rows))?;
389
390        Ok(RowIter {
391            columns,
392            rows: rows.into_iter(),
393        })
394    }
395
396    pub fn execute(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error> {
397        let (param_types, _) = self.prepare_query(query, params.len())?;
398        self.bind_execute(params.iter().copied(), &param_types, None)
399    }
400
401    pub fn query(
402        &mut self,
403        query: &str,
404        params: &[&(dyn ToSql + Sync)],
405    ) -> Result<Vec<Row>, Error> {
406        self.query_raw(query, params.iter().copied())?.collect()
407    }
408
409    pub fn query_one(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<Row, Error> {
410        let mut it = self.query_raw(query, params.iter().copied())?;
411        let first = it.next()?.ok_or("no rows returned")?;
412        if it.next()?.is_some() {
413            return Err("more than one row returned".into());
414        }
415        Ok(first)
416    }
417
418    pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
419        frontend::query(query, &mut self.write_buf)?;
420        self.flush()?;
421
422        loop {
423            match self.read_message()? {
424                backend::Message::ReadyForQuery(_) => return Ok(()),
425                backend::Message::CommandComplete(_)
426                | backend::Message::EmptyQueryResponse
427                | backend::Message::RowDescription(_)
428                | backend::Message::DataRow(_) => {}
429                backend::Message::ErrorResponse(body) => {
430                    let err = DbError::parse(body.fields());
431                    self.drain_ready()?;
432                    return Err(err.into());
433                }
434                _ => return Err("unexpected message".into()),
435            }
436        }
437    }
438
439    fn parse_data_row(&self, body: backend::DataRowBody) -> Result<Vec<Option<Vec<u8>>>, Error> {
440        let mut out = Vec::new();
441        let mut ranges = body.ranges();
442        let buf = body.buffer();
443        while let Some(range) = ranges.next()? {
444            match range {
445                Some(r) => out.push(Some(buf[r].to_vec())),
446                None => out.push(None),
447            }
448        }
449        Ok(out)
450    }
451}
452
453pub struct Row {
454    columns: Vec<(String, Oid)>,
455    values: Vec<Option<Vec<u8>>>,
456}
457
458pub trait RowIndex {
459    fn idx(&self, columns: &[(String, Oid)]) -> Option<usize>;
460}
461
462impl RowIndex for usize {
463    fn idx(&self, columns: &[(String, Oid)]) -> Option<usize> {
464        if *self < columns.len() { Some(*self) } else { None }
465    }
466}
467
468impl RowIndex for &str {
469    fn idx(&self, columns: &[(String, Oid)]) -> Option<usize> {
470        columns.iter()
471            .position(|(name, _)| name == self)
472        .or_else(|| columns.iter()
473            .position(|(name, _)| name.eq_ignore_ascii_case(self)))
474    }
475}
476
477
478impl Row {
479    pub fn get<'a, T>(&'a self, idx: impl RowIndex) -> T
480    where
481        T: FromSql<'a>,
482    {
483        let idx = idx.idx(&self.columns).expect("invalid column");
484        let (_, oid) = &self.columns[idx];
485        let ty = Type::from_oid(*oid).unwrap_or(Type::TEXT);
486        let raw = self.values[idx].as_deref();
487        FromSql::from_sql_nullable(&ty, raw).unwrap()
488    }
489}
490
491pub struct RowIter {
492    columns: Vec<(String, Oid)>,
493    rows: std::vec::IntoIter<Vec<Option<Vec<u8>>>>,
494}
495
496impl FallibleIterator for RowIter {
497    type Item = Row;
498    type Error = Error;
499
500    fn next(&mut self) -> Result<Option<Row>, Error> {
501        Ok(self.rows.next().map(|values| Row {
502            columns: self.columns.clone(),
503            values,
504        }))
505    }
506}
507