1use std::error::Error as StdError;
12use std::io::{Read, Write};
13use std::net::TcpStream;
14use std::time::Duration;
15
16use bytes::BytesMut;
17use fallible_iterator::FallibleIterator;
18use postgres_protocol::Oid;
19use postgres_protocol::authentication::{
20 md5_hash,
21 sasl::{self, ChannelBinding, ScramSha256},
22};
23use postgres_protocol::message::backend;
24use postgres_protocol::message::frontend;
25use postgres_types::{IsNull, Type};
26use socket2::{SockRef, TcpKeepalive};
27
28pub use fallible_iterator;
29pub use postgres_types::{BorrowToSql, FromSql, ToSql};
30pub use postgres_types as types;
31
32pub use crate::transaction::Transaction;
33pub use crate::config::Config;
34
35mod config;
36mod transaction;
37
/// Boxed, thread-safe error type returned by the fallible operations in this file.
///
/// Any `std::error::Error + Send + Sync` value, as well as a plain `&str` message, converts
/// into it through `?` or `.into()`.
pub type Error = Box<dyn StdError + Sync + Send + 'static>;
39
40#[derive(Debug)]
41pub struct DbError {
42 severity: String,
43 code: String,
44 message: String,
45 detail: Option<String>,
46 hint: Option<String>,
47 position: Option<ErrorPosition>,
48}
49
50impl std::fmt::Display for DbError {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(f, "{}: {} ({})", self.severity, self.message, self.code)?;
53 if let Some(detail) = &self.detail {
54 write!(f, "\nDETAIL: {detail}")?;
55 }
56 if let Some(hint) = &self.hint {
57 write!(f, "\nHINT: {hint}")?;
58 }
59 if let Some(pos) = &self.position {
60 write!(f, "\nPOSITION: {pos:?}")?;
61 }
62 Ok(())
63 }
64}
65
66impl StdError for DbError {}
67
68impl DbError {
69 fn parse(mut fields: backend::ErrorFields<'_>) -> Self {
70 let mut severity = String::new();
71 let mut code = String::new();
72 let mut message = String::new();
73 let mut detail = None;
74 let mut hint = None;
75 let mut normal_position = None;
76 let mut internal_position = None;
77 let mut internal_query = None;
78 while let Some(field) = fields.next().unwrap() {
79 match field.type_() {
80 b'S' => severity = String::from_utf8_lossy(field.value_bytes()).into_owned(),
81 b'C' => code = String::from_utf8_lossy(field.value_bytes()).into_owned(),
82 b'M' => message = String::from_utf8_lossy(field.value_bytes()).into_owned(),
83 b'D' => detail = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
84 b'H' => hint = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
85 b'P' => normal_position = String::from_utf8_lossy(field.value_bytes()).parse().ok(),
86 b'p' => internal_position = String::from_utf8_lossy(field.value_bytes()).parse().ok(),
87 b'q' => internal_query = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
88 _ => {}
89 }
90 }
91 let position = match normal_position {
92 Some(pos) => Some(ErrorPosition::Original(pos)),
93 None => internal_position.map(|pos| ErrorPosition::Internal {
94 position: pos,
95 query: internal_query.unwrap_or_default(),
96 }),
97 };
98 Self { severity, code, message, detail, hint, position }
99 }
100}
101
/// Location within a query at which a server-reported error occurred.
#[derive(Debug)]
pub enum ErrorPosition {
    /// A position in the query text sent by the client.
    Original(u32),
    /// A position in an internally generated query, together with that query's text.
    Internal {
        /// The position within `query`.
        position: u32,
        /// The internally generated query text (empty if the server did not send it).
        query: String,
    },
}
107
/// Connection-mode marker requesting a plain, unencrypted TCP connection.
#[derive(Clone, Copy, Debug)]
pub struct NoTls;
110
111pub struct Client {
112 stream: TcpStream,
113 read_buf: BytesMut,
114 write_buf: BytesMut,
115}
116
117impl Client {
118 pub fn connect(s: &str, _tls: NoTls) -> Result<Client, Error> {
119 let config = config::Config::parse(s)?;
120 Self::connect_config(&config, _tls)
121 }
122
123 fn connect_config(config: &config::Config, _tls: NoTls) -> Result<Client, Error> {
124 let stream = TcpStream::connect((config.host.as_str(), config.port))?;
125
126 let sock_ref = SockRef::from(&stream);
127 let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(50));
128 sock_ref.set_tcp_keepalive(&keepalive)?;
129
130 let user = &config.user;
131 let db = &config.db;
132
133 let mut this = Client {
134 stream,
135 read_buf: BytesMut::with_capacity(8192),
136 write_buf: BytesMut::with_capacity(8192),
137 };
138
139 let mut params: Vec<(&str, &str)> = Vec::new();
140 params.push(("user", user));
141 if !db.is_empty() {
142 params.push(("database", db));
143 }
144 params.push(("client_encoding", "UTF8"));
145
146 frontend::startup_message(params.iter().copied(), &mut this.write_buf)?;
147 this.flush()?;
148
149 this.handle_auth(user.as_bytes(), &config.password)?;
150
151 loop {
152 match this.read_message()? {
153 backend::Message::ReadyForQuery(_) => break,
154 backend::Message::BackendKeyData(_) => {}
155 backend::Message::ParameterStatus(_) => {}
156 backend::Message::ErrorResponse(body) => return Err(DbError::parse(body.fields()).into()),
157 _ => return Err("unexpected message".into()),
158 }
159 }
160
161 Ok(this)
162 }
163
164 fn handle_auth(&mut self, user: &[u8], password: &str) -> Result<(), Error> {
165 loop {
166 match self.read_message()? {
167 backend::Message::AuthenticationOk => break,
168 backend::Message::AuthenticationCleartextPassword => {
169 frontend::password_message(password.as_bytes(), &mut self.write_buf)?;
171 self.flush()?;
172 }
173 backend::Message::AuthenticationMd5Password(body) => {
174 let output = md5_hash(user, password.as_bytes(), body.salt());
176 frontend::password_message(output.as_bytes(), &mut self.write_buf)?;
177 self.flush()?;
178 }
179 backend::Message::AuthenticationSasl(body) => {
180 let mut has_scram = false;
181 let mut mechs = body.mechanisms();
182 while let Some(mech) = mechs.next()? {
183 if mech == sasl::SCRAM_SHA_256 {
184 has_scram = true;
185 }
186 }
187 if !has_scram {
188 return Err("unsupported authentication".into());
189 }
190
191 let mut scram =
192 ScramSha256::new(password.as_bytes(), ChannelBinding::unsupported());
193
194 frontend::sasl_initial_response(
195 sasl::SCRAM_SHA_256,
196 scram.message(),
197 &mut self.write_buf,
198 )?;
199 self.flush()?;
200
201 let body = match self.read_message()? {
202 backend::Message::AuthenticationSaslContinue(body) => body,
203 backend::Message::ErrorResponse(body) => return Err(DbError::parse(body.fields()).into()),
204 _ => return Err("unexpected message".into()),
205 };
206
207 scram.update(body.data())?;
208
209 frontend::sasl_response(scram.message(), &mut self.write_buf)?;
210 self.flush()?;
211
212 let body = match self.read_message()? {
213 backend::Message::AuthenticationSaslFinal(body) => body,
214 backend::Message::ErrorResponse(body) => return Err(DbError::parse(body.fields()).into()),
215 _ => return Err("unexpected message".into()),
216 };
217
218 scram.finish(body.data())?;
219 }
220 backend::Message::ErrorResponse(body) => {
221 return Err(DbError::parse(body.fields()).into());
222 }
223 _ => return Err("unsupported authentication".into()),
224 }
225 }
226 Ok(())
227 }
228
229 fn flush(&mut self) -> Result<(), Error> {
230 self.stream.write_all(&self.write_buf)?;
231 self.stream.flush()?;
232 self.write_buf.clear();
233 Ok(())
234 }
235
236 fn read_message(&mut self) -> Result<backend::Message, Error> {
237 loop {
238 if let Some(message) = backend::Message::parse(&mut self.read_buf)? {
239 if let backend::Message::NoticeResponse(body) = &message {
240 log::info!("postgres notice: {:?}", DbError::parse(body.fields()));
241 continue;
242 }
243 return Ok(message);
244 }
245 let mut buf = [0u8; 8192];
246 let n = self.stream.read(&mut buf)?;
247 if n == 0 {
248 return Err("unexpected EOF".into());
249 }
250 self.read_buf.extend_from_slice(&buf[..n]);
251 }
252 }
253
254 fn drain_ready(&mut self) -> Result<(), Error> {
255 loop {
256 match self.read_message()? {
257 backend::Message::ReadyForQuery(_) => return Ok(()),
258 backend::Message::ErrorResponse(body) => {
259 return Err(DbError::parse(body.fields()).into())
260 }
261 _ => {}
262 }
263 }
264 }
265
266 #[allow(clippy::type_complexity)]
267 fn prepare_query(
268 &mut self,
269 query: &str,
270 params_len: usize,
271 ) -> Result<(Vec<Type>, Vec<(String, Oid)>), Error> {
272 let param_oids = vec![0; params_len];
273 frontend::parse("", query, param_oids.iter().copied(), &mut self.write_buf)?;
274 frontend::describe(b'S', "", &mut self.write_buf)?;
275 frontend::sync(&mut self.write_buf);
276 self.flush()?;
277
278 let mut param_types = Vec::new();
279 let mut columns = Vec::new();
280 loop {
281 match self.read_message()? {
282 backend::Message::ParseComplete => {}
283 backend::Message::ParameterDescription(body) => {
284 let mut it = body.parameters();
285 while let Some(oid) = it.next()? {
286 let ty = Type::from_oid(oid).unwrap_or(Type::TEXT);
287 param_types.push(ty);
288 }
289 }
290 backend::Message::RowDescription(body) => {
291 let mut fields = body.fields();
292 while let Some(field) = fields.next()? {
293 columns.push((field.name().to_string(), field.type_oid()));
294 }
295 }
296 backend::Message::NoData => {}
297 backend::Message::ReadyForQuery(_) => break,
298 backend::Message::ErrorResponse(body) => {
299 let err = DbError::parse(body.fields());
300 self.drain_ready()?;
301 return Err(err.into());
302 }
303 _ => return Err("unexpected message".into()),
304 }
305 }
306
307 Ok((param_types, columns))
308 }
309
310 fn bind_execute<P, I>(
311 &mut self,
312 params: I,
313 param_types: &[Type],
314 mut rows: Option<&mut Vec<Vec<Option<Vec<u8>>>>>,
315 ) -> Result<u64, Error>
316 where
317 P: BorrowToSql,
318 I: IntoIterator<Item = P>,
319 I::IntoIter: ExactSizeIterator,
320 {
321 let params: Vec<P> = params.into_iter().collect();
322 assert_eq!(param_types.len(), params.len());
323 let param_formats: Vec<i16> = params
324 .iter()
325 .zip(param_types)
326 .map(|(p, t)| p.borrow_to_sql().encode_format(t) as i16)
327 .collect();
328
329 frontend::bind(
330 "",
331 "",
332 param_formats,
333 params.iter().zip(param_types.iter()),
334 |(param, ty), buf| match param.borrow_to_sql().to_sql_checked(ty, buf)? {
335 IsNull::No => Ok(postgres_protocol::IsNull::No),
336 IsNull::Yes => Ok(postgres_protocol::IsNull::Yes),
337 },
338 Some(1),
339 &mut self.write_buf,
340 )
341 .map_err(|e| match e {
342 frontend::BindError::Conversion(e) => e,
343 frontend::BindError::Serialization(e) => Box::new(e) as Error,
344 })?;
345 frontend::execute("", 0, &mut self.write_buf)?;
346 frontend::sync(&mut self.write_buf);
347 self.flush()?;
348
349 let mut rows_affected = 0;
350 loop {
351 match self.read_message()? {
352 backend::Message::BindComplete => {}
353 backend::Message::DataRow(body) => {
354 if let Some(out) = rows.as_mut() {
355 out.push(self.parse_data_row(body)?);
356 }
357 }
358 backend::Message::CommandComplete(body) => {
359 let tag = body.tag().map_err(|e| Box::new(e) as Error)?;
360 rows_affected = tag
361 .rsplit(' ')
362 .next()
363 .and_then(|s| s.parse().ok())
364 .unwrap_or(0);
365 }
366 backend::Message::EmptyQueryResponse => rows_affected = 0,
367 backend::Message::ReadyForQuery(_) => return Ok(rows_affected),
368 backend::Message::ErrorResponse(body) => {
369 let err = DbError::parse(body.fields());
370 self.drain_ready()?;
371 return Err(err.into());
372 }
373 _ => return Err("unexpected message".into()),
374 }
375 }
376 }
377
378 pub fn query_raw<P, I>(&mut self, query: &str, params: I) -> Result<RowIter, Error>
379 where
380 P: BorrowToSql,
381 I: IntoIterator<Item = P>,
382 I::IntoIter: ExactSizeIterator,
383 {
384 let params = params.into_iter();
385 let (param_types, columns) = self.prepare_query(query, params.len())?;
386 let params: Vec<P> = params.collect();
387 let mut rows = Vec::new();
388 self.bind_execute(params, ¶m_types, Some(&mut rows))?;
389
390 Ok(RowIter {
391 columns,
392 rows: rows.into_iter(),
393 })
394 }
395
396 pub fn execute(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error> {
397 let (param_types, _) = self.prepare_query(query, params.len())?;
398 self.bind_execute(params.iter().copied(), ¶m_types, None)
399 }
400
401 pub fn query(
402 &mut self,
403 query: &str,
404 params: &[&(dyn ToSql + Sync)],
405 ) -> Result<Vec<Row>, Error> {
406 self.query_raw(query, params.iter().copied())?.collect()
407 }
408
409 pub fn query_one(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<Row, Error> {
410 let mut it = self.query_raw(query, params.iter().copied())?;
411 let first = it.next()?.ok_or("no rows returned")?;
412 if it.next()?.is_some() {
413 return Err("more than one row returned".into());
414 }
415 Ok(first)
416 }
417
418 pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
419 frontend::query(query, &mut self.write_buf)?;
420 self.flush()?;
421
422 loop {
423 match self.read_message()? {
424 backend::Message::ReadyForQuery(_) => return Ok(()),
425 backend::Message::CommandComplete(_)
426 | backend::Message::EmptyQueryResponse
427 | backend::Message::RowDescription(_)
428 | backend::Message::DataRow(_) => {}
429 backend::Message::ErrorResponse(body) => {
430 let err = DbError::parse(body.fields());
431 self.drain_ready()?;
432 return Err(err.into());
433 }
434 _ => return Err("unexpected message".into()),
435 }
436 }
437 }
438
439 fn parse_data_row(&self, body: backend::DataRowBody) -> Result<Vec<Option<Vec<u8>>>, Error> {
440 let mut out = Vec::new();
441 let mut ranges = body.ranges();
442 let buf = body.buffer();
443 while let Some(range) = ranges.next()? {
444 match range {
445 Some(r) => out.push(Some(buf[r].to_vec())),
446 None => out.push(None),
447 }
448 }
449 Ok(out)
450 }
451}
452
453pub struct Row {
454 columns: Vec<(String, Oid)>,
455 values: Vec<Option<Vec<u8>>>,
456}
457
458pub trait RowIndex {
459 fn idx(&self, columns: &[(String, Oid)]) -> Option<usize>;
460}
461
462impl RowIndex for usize {
463 fn idx(&self, columns: &[(String, Oid)]) -> Option<usize> {
464 if *self < columns.len() { Some(*self) } else { None }
465 }
466}
467
468impl RowIndex for &str {
469 fn idx(&self, columns: &[(String, Oid)]) -> Option<usize> {
470 columns.iter()
471 .position(|(name, _)| name == self)
472 .or_else(|| columns.iter()
473 .position(|(name, _)| name.eq_ignore_ascii_case(self)))
474 }
475}
476
477
478impl Row {
479 pub fn get<'a, T>(&'a self, idx: impl RowIndex) -> T
480 where
481 T: FromSql<'a>,
482 {
483 let idx = idx.idx(&self.columns).expect("invalid column");
484 let (_, oid) = &self.columns[idx];
485 let ty = Type::from_oid(*oid).unwrap_or(Type::TEXT);
486 let raw = self.values[idx].as_deref();
487 FromSql::from_sql_nullable(&ty, raw).unwrap()
488 }
489}
490
491pub struct RowIter {
492 columns: Vec<(String, Oid)>,
493 rows: std::vec::IntoIter<Vec<Option<Vec<u8>>>>,
494}
495
496impl FallibleIterator for RowIter {
497 type Item = Row;
498 type Error = Error;
499
500 fn next(&mut self) -> Result<Option<Row>, Error> {
501 Ok(self.rows.next().map(|values| Row {
502 columns: self.columns.clone(),
503 values,
504 }))
505 }
506}
507