1use std::error::Error as StdError;
12use std::io::{Read, Write};
13use std::net::TcpStream;
14use std::time::Duration;
15
16use bytes::BytesMut;
17use fallible_iterator::FallibleIterator;
18use postgres_protocol::Oid;
19use postgres_protocol::authentication::{
20 md5_hash,
21 sasl::{self, ChannelBinding, ScramSha256},
22};
23use postgres_protocol::message::backend;
24use postgres_protocol::message::frontend;
25use postgres_types::{IsNull, Type};
26use socket2::{SockRef, TcpKeepalive};
27
28pub use fallible_iterator;
29pub use postgres_types::{BorrowToSql, FromSql, ToSql};
30pub use postgres_types as types;
31
32pub use crate::transaction::Transaction;
33pub use crate::config::Config;
34
35mod config;
36mod transaction;
37
/// Boxed, thread-safe error type returned by every fallible operation in this crate.
///
/// Database errors are boxed [`DbError`] values and can be recovered with
/// `downcast_ref::<DbError>()`; I/O and protocol failures are boxed as their own types.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
39
/// An error reported by the server in an `ErrorResponse` message.
///
/// The `Display` form is `SEVERITY: message (code)`, followed by optional
/// `DETAIL`, `HINT` and `POSITION` lines.
#[derive(Debug, Clone)]
pub struct DbError {
    severity: String,
    code: String,
    message: String,
    detail: Option<String>,
    hint: Option<String>,
    position: Option<ErrorPosition>,
}

impl DbError {
    /// The severity string reported by the server (its `S` field).
    pub fn severity(&self) -> &str {
        &self.severity
    }

    /// The SQLSTATE error code (the `C` field); useful for branching on specific
    /// error conditions without parsing the message text.
    pub fn code(&self) -> &str {
        &self.code
    }

    /// The primary human-readable message (the `M` field).
    pub fn message(&self) -> &str {
        &self.message
    }

    /// The optional detail message (the `D` field).
    pub fn detail(&self) -> Option<&str> {
        self.detail.as_deref()
    }

    /// The optional hint (the `H` field).
    pub fn hint(&self) -> Option<&str> {
        self.hint.as_deref()
    }

    /// Where the error occurred in the query text, if the server reported it.
    pub fn position(&self) -> Option<&ErrorPosition> {
        self.position.as_ref()
    }
}

impl std::fmt::Display for DbError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {} ({})", self.severity, self.message, self.code)?;
        if let Some(detail) = &self.detail {
            write!(f, "\nDETAIL: {detail}")?;
        }
        if let Some(hint) = &self.hint {
            write!(f, "\nHINT: {hint}")?;
        }
        if let Some(pos) = &self.position {
            write!(f, "\nPOSITION: {pos:?}")?;
        }
        Ok(())
    }
}

impl StdError for DbError {}

/// Location of an error in query text, as reported by the server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorPosition {
    /// A position in the query the client sent (the `P` field).
    Original(u32),
    /// A position in an internally generated query, together with that query's
    /// text (the `p` and `q` fields).
    Internal { position: u32, query: String },
}
73
/// Connection mode marker passed to `Client::connect`: the client only ever opens
/// plaintext TCP connections and never negotiates TLS.
#[derive(Clone, Copy, Debug)]
pub struct NoTls;
76
77pub struct Client {
78 stream: TcpStream,
79 read_buf: BytesMut,
80 write_buf: BytesMut,
81}
82
83impl Client {
84 pub fn connect(s: &str, _tls: NoTls) -> Result<Client, Error> {
85 let config = config::Config::parse(s)?;
86 Self::connect_config(&config, _tls)
87 }
88
89 fn connect_config(config: &config::Config, _tls: NoTls) -> Result<Client, Error> {
90 let stream = TcpStream::connect((config.host.as_str(), config.port))?;
91
92 let sock_ref = SockRef::from(&stream);
93 let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(50));
94 sock_ref.set_tcp_keepalive(&keepalive)?;
95
96 let user = &config.user;
97 let db = &config.db;
98
99 let mut this = Client {
100 stream,
101 read_buf: BytesMut::with_capacity(8192),
102 write_buf: BytesMut::with_capacity(8192),
103 };
104
105 let mut params: Vec<(&str, &str)> = Vec::new();
106 params.push(("user", user));
107 if !db.is_empty() {
108 params.push(("database", db));
109 }
110 params.push(("client_encoding", "UTF8"));
111
112 frontend::startup_message(params.iter().copied(), &mut this.write_buf)?;
113 this.flush()?;
114
115 this.handle_auth(user.as_bytes(), &config.password)?;
116
117 loop {
118 match this.read_message()? {
119 backend::Message::ReadyForQuery(_) => break,
120 backend::Message::BackendKeyData(_) => {}
121 backend::Message::ParameterStatus(_) => {}
122 backend::Message::ErrorResponse(body) => return Err(this.error_response(body).into()),
123 _ => return Err("unexpected message".into()),
124 }
125 }
126
127 Ok(this)
128 }
129
130 fn handle_auth(&mut self, user: &[u8], password: &str) -> Result<(), Error> {
131 loop {
132 match self.read_message()? {
133 backend::Message::AuthenticationOk => break,
134 backend::Message::AuthenticationCleartextPassword => {
135 frontend::password_message(password.as_bytes(), &mut self.write_buf)?;
137 self.flush()?;
138 }
139 backend::Message::AuthenticationMd5Password(body) => {
140 let output = md5_hash(user, password.as_bytes(), body.salt());
142 frontend::password_message(output.as_bytes(), &mut self.write_buf)?;
143 self.flush()?;
144 }
145 backend::Message::AuthenticationSasl(body) => {
146 let mut has_scram = false;
147 let mut mechs = body.mechanisms();
148 while let Some(mech) = mechs.next()? {
149 if mech == sasl::SCRAM_SHA_256 {
150 has_scram = true;
151 }
152 }
153 if !has_scram {
154 return Err("unsupported authentication".into());
155 }
156
157 let mut scram =
158 ScramSha256::new(password.as_bytes(), ChannelBinding::unsupported());
159
160 frontend::sasl_initial_response(
161 sasl::SCRAM_SHA_256,
162 scram.message(),
163 &mut self.write_buf,
164 )?;
165 self.flush()?;
166
167 let body = match self.read_message()? {
168 backend::Message::AuthenticationSaslContinue(body) => body,
169 backend::Message::ErrorResponse(body) => return Err(self.error_response(body).into()),
170 _ => return Err("unexpected message".into()),
171 };
172
173 scram.update(body.data())?;
174
175 frontend::sasl_response(scram.message(), &mut self.write_buf)?;
176 self.flush()?;
177
178 let body = match self.read_message()? {
179 backend::Message::AuthenticationSaslFinal(body) => body,
180 backend::Message::ErrorResponse(body) => return Err(self.error_response(body).into()),
181 _ => return Err("unexpected message".into()),
182 };
183
184 scram.finish(body.data())?;
185 }
186 backend::Message::ErrorResponse(body) => {
187 return Err(self.error_response(body).into());
188 }
189 _ => return Err("unsupported authentication".into()),
190 }
191 }
192 Ok(())
193 }
194
195 fn flush(&mut self) -> Result<(), Error> {
196 self.stream.write_all(&self.write_buf)?;
197 self.stream.flush()?;
198 self.write_buf.clear();
199 Ok(())
200 }
201
202 fn read_message(&mut self) -> Result<backend::Message, Error> {
203 loop {
204 if let Some(message) = backend::Message::parse(&mut self.read_buf)? {
205 return Ok(message);
206 }
207 let mut buf = [0u8; 8192];
208 let n = self.stream.read(&mut buf)?;
209 if n == 0 {
210 return Err("unexpected EOF".into());
211 }
212 self.read_buf.extend_from_slice(&buf[..n]);
213 }
214 }
215
216 fn error_response(&self, body: backend::ErrorResponseBody) -> DbError {
217 let mut severity = String::new();
218 let mut code = String::new();
219 let mut message = String::new();
220 let mut detail = None;
221 let mut hint = None;
222 let mut normal_position = None;
223 let mut internal_position = None;
224 let mut internal_query = None;
225 let mut fields = body.fields();
226 while let Some(field) = fields.next().unwrap() {
227 match field.type_() {
228 b'S' => severity = String::from_utf8_lossy(field.value_bytes()).into_owned(),
229 b'C' => code = String::from_utf8_lossy(field.value_bytes()).into_owned(),
230 b'M' => message = String::from_utf8_lossy(field.value_bytes()).into_owned(),
231 b'D' => detail = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
232 b'H' => hint = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
233 b'P' => normal_position = String::from_utf8_lossy(field.value_bytes()).parse().ok(),
234 b'p' => internal_position = String::from_utf8_lossy(field.value_bytes()).parse().ok(),
235 b'q' => internal_query = Some(String::from_utf8_lossy(field.value_bytes()).into_owned()),
236 _ => {}
237 }
238 }
239 let position = match normal_position {
240 Some(pos) => Some(ErrorPosition::Original(pos)),
241 None => internal_position.map(|pos| ErrorPosition::Internal {
242 position: pos,
243 query: internal_query.unwrap_or_default(),
244 }),
245 };
246 DbError { severity, code, message, detail, hint, position }
247 }
248
249 fn drain_ready(&mut self) -> Result<(), Error> {
250 loop {
251 match self.read_message()? {
252 backend::Message::ReadyForQuery(_) => return Ok(()),
253 backend::Message::ErrorResponse(body) => {
254 return Err(self.error_response(body).into())
255 }
256 _ => {}
257 }
258 }
259 }
260
261 #[allow(clippy::type_complexity)]
262 fn prepare_query(
263 &mut self,
264 query: &str,
265 params_len: usize,
266 ) -> Result<(Vec<Type>, Vec<(String, Oid)>), Error> {
267 let param_oids = vec![0; params_len];
268 frontend::parse("", query, param_oids.iter().copied(), &mut self.write_buf)?;
269 frontend::describe(b'S', "", &mut self.write_buf)?;
270 frontend::sync(&mut self.write_buf);
271 self.flush()?;
272
273 let mut param_types = Vec::new();
274 let mut columns = Vec::new();
275 loop {
276 match self.read_message()? {
277 backend::Message::ParseComplete => {}
278 backend::Message::ParameterDescription(body) => {
279 let mut it = body.parameters();
280 while let Some(oid) = it.next()? {
281 let ty = Type::from_oid(oid).unwrap_or(Type::TEXT);
282 param_types.push(ty);
283 }
284 }
285 backend::Message::RowDescription(body) => {
286 let mut fields = body.fields();
287 while let Some(field) = fields.next()? {
288 columns.push((field.name().to_string(), field.type_oid()));
289 }
290 }
291 backend::Message::NoData => {}
292 backend::Message::ReadyForQuery(_) => break,
293 backend::Message::ErrorResponse(body) => {
294 let err = self.error_response(body);
295 self.drain_ready()?;
296 return Err(err.into());
297 }
298 _ => return Err("unexpected message".into()),
299 }
300 }
301
302 Ok((param_types, columns))
303 }
304
305 fn bind_execute<P, I>(
306 &mut self,
307 params: I,
308 param_types: &[Type],
309 mut rows: Option<&mut Vec<Vec<Option<Vec<u8>>>>>,
310 ) -> Result<u64, Error>
311 where
312 P: BorrowToSql,
313 I: IntoIterator<Item = P>,
314 I::IntoIter: ExactSizeIterator,
315 {
316 let params: Vec<P> = params.into_iter().collect();
317 assert_eq!(param_types.len(), params.len());
318 let param_formats: Vec<i16> = params
319 .iter()
320 .zip(param_types)
321 .map(|(p, t)| p.borrow_to_sql().encode_format(t) as i16)
322 .collect();
323
324 frontend::bind(
325 "",
326 "",
327 param_formats,
328 params.iter().zip(param_types.iter()),
329 |(param, ty), buf| match param.borrow_to_sql().to_sql_checked(ty, buf)? {
330 IsNull::No => Ok(postgres_protocol::IsNull::No),
331 IsNull::Yes => Ok(postgres_protocol::IsNull::Yes),
332 },
333 Some(1),
334 &mut self.write_buf,
335 )
336 .map_err(|e| match e {
337 frontend::BindError::Conversion(e) => e,
338 frontend::BindError::Serialization(e) => Box::new(e) as Error,
339 })?;
340 frontend::execute("", 0, &mut self.write_buf)?;
341 frontend::sync(&mut self.write_buf);
342 self.flush()?;
343
344 let mut rows_affected = 0;
345 loop {
346 match self.read_message()? {
347 backend::Message::BindComplete => {}
348 backend::Message::DataRow(body) => {
349 if let Some(out) = rows.as_mut() {
350 out.push(self.parse_data_row(body)?);
351 }
352 }
353 backend::Message::CommandComplete(body) => {
354 let tag = body.tag().map_err(|e| Box::new(e) as Error)?;
355 rows_affected = tag
356 .rsplit(' ')
357 .next()
358 .and_then(|s| s.parse().ok())
359 .unwrap_or(0);
360 }
361 backend::Message::EmptyQueryResponse => rows_affected = 0,
362 backend::Message::ReadyForQuery(_) => return Ok(rows_affected),
363 backend::Message::ErrorResponse(body) => {
364 let err = self.error_response(body);
365 self.drain_ready()?;
366 return Err(err.into());
367 }
368 _ => return Err("unexpected message".into()),
369 }
370 }
371 }
372
373 pub fn query_raw<P, I>(&mut self, query: &str, params: I) -> Result<RowIter, Error>
374 where
375 P: BorrowToSql,
376 I: IntoIterator<Item = P>,
377 I::IntoIter: ExactSizeIterator,
378 {
379 let params = params.into_iter();
380 let (param_types, columns) = self.prepare_query(query, params.len())?;
381 let params: Vec<P> = params.collect();
382 let mut rows = Vec::new();
383 self.bind_execute(params, ¶m_types, Some(&mut rows))?;
384
385 Ok(RowIter {
386 columns,
387 rows: rows.into_iter(),
388 })
389 }
390
391 pub fn execute(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error> {
392 let (param_types, _) = self.prepare_query(query, params.len())?;
393 self.bind_execute(params.iter().copied(), ¶m_types, None)
394 }
395
396 pub fn query(
397 &mut self,
398 query: &str,
399 params: &[&(dyn ToSql + Sync)],
400 ) -> Result<Vec<Row>, Error> {
401 self.query_raw(query, params.iter().copied())?.collect()
402 }
403
404 pub fn query_one(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<Row, Error> {
405 let mut it = self.query_raw(query, params.iter().copied())?;
406 let first = it.next()?.ok_or("no rows returned")?;
407 if it.next()?.is_some() {
408 return Err("more than one row returned".into());
409 }
410 Ok(first)
411 }
412
413 pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
414 frontend::query(query, &mut self.write_buf)?;
415 self.flush()?;
416
417 loop {
418 match self.read_message()? {
419 backend::Message::ReadyForQuery(_) => return Ok(()),
420 backend::Message::CommandComplete(_)
421 | backend::Message::EmptyQueryResponse
422 | backend::Message::RowDescription(_)
423 | backend::Message::DataRow(_) => {}
424 backend::Message::ErrorResponse(body) => {
425 let err = self.error_response(body);
426 self.drain_ready()?;
427 return Err(err.into());
428 }
429 _ => return Err("unexpected message".into()),
430 }
431 }
432 }
433
434 fn parse_data_row(&self, body: backend::DataRowBody) -> Result<Vec<Option<Vec<u8>>>, Error> {
435 let mut out = Vec::new();
436 let mut ranges = body.ranges();
437 let buf = body.buffer();
438 while let Some(range) = ranges.next()? {
439 match range {
440 Some(r) => out.push(Some(buf[r].to_vec())),
441 None => out.push(None),
442 }
443 }
444 Ok(out)
445 }
446}
447
448pub struct Row {
449 columns: Vec<(String, Oid)>,
450 values: Vec<Option<Vec<u8>>>,
451}
452
453pub trait RowIndex {
454 fn idx(&self, columns: &[(String, Oid)]) -> Option<usize>;
455}
456
457impl RowIndex for usize {
458 fn idx(&self, columns: &[(String, Oid)]) -> Option<usize> {
459 if *self < columns.len() { Some(*self) } else { None }
460 }
461}
462
463impl RowIndex for &str {
464 fn idx(&self, columns: &[(String, Oid)]) -> Option<usize> {
465 columns.iter()
466 .position(|(name, _)| name == self)
467 .or_else(|| columns.iter()
468 .position(|(name, _)| name.eq_ignore_ascii_case(self)))
469 }
470}
471
472
473impl Row {
474 pub fn get<'a, T>(&'a self, idx: impl RowIndex) -> T
475 where
476 T: FromSql<'a>,
477 {
478 let idx = idx.idx(&self.columns).expect("invalid column");
479 let (_, oid) = &self.columns[idx];
480 let ty = Type::from_oid(*oid).unwrap_or(Type::TEXT);
481 let raw = self.values[idx].as_deref();
482 FromSql::from_sql_nullable(&ty, raw).unwrap()
483 }
484}
485
486pub struct RowIter {
487 columns: Vec<(String, Oid)>,
488 rows: std::vec::IntoIter<Vec<Option<Vec<u8>>>>,
489}
490
491impl FallibleIterator for RowIter {
492 type Item = Row;
493 type Error = Error;
494
495 fn next(&mut self) -> Result<Option<Row>, Error> {
496 Ok(self.rows.next().map(|values| Row {
497 columns: self.columns.clone(),
498 values,
499 }))
500 }
501}
502