Skip to main content

we_trust_postgres/
connection.rs

1use crate::codec::{BackendMessage, PgCodec};
2use crate::message::Message;
3use crate::transaction::PostgresTransaction;
4use futures::{SinkExt, StreamExt};
5use tokio::net::TcpStream;
6use tokio_util::codec::Framed;
7use tracing::{debug, info};
8use yykv_types::{ColumnInfo, DsError, DsResult, DsValue, EnumInfo, SchemaInspector, TableInfo};
9
10#[async_trait::async_trait]
11impl SchemaInspector for PostgresConnection {
12    async fn introspect(&self, schema: Option<&str>) -> DsResult<(Vec<TableInfo>, Vec<EnumInfo>)> {
13        let schema = schema.unwrap_or("public");
14
15        // 1. Fetch enums
16        let enum_sql = format!(
17            "SELECT 
18                t.typname as enum_name,
19                e.enumlabel as enum_variant
20             FROM pg_type t
21             JOIN pg_enum e ON t.oid = e.enumtypid
22             JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
23             WHERE n.nspname = '{}'
24             ORDER BY t.typname, e.enumsortorder",
25            schema
26        );
27
28        let enum_rows = self.query(&enum_sql, &[]).await?;
29        let mut enums = Vec::new();
30        let mut current_enum: Option<EnumInfo> = None;
31
32        for row in enum_rows {
33            if let DsValue::List(fields) = row {
34                let name = match fields.first() {
35                    Some(DsValue::Text(s)) => s.clone(),
36                    _ => continue,
37                };
38                let variant = match fields.get(1) {
39                    Some(DsValue::Text(s)) => s.clone(),
40                    _ => continue,
41                };
42
43                if let Some(ref mut e) = current_enum {
44                    if e.name == name {
45                        e.variants.push(variant);
46                        continue;
47                    } else {
48                        enums.push(current_enum.take().unwrap());
49                    }
50                }
51                current_enum = Some(EnumInfo {
52                    name,
53                    variants: vec![variant],
54                });
55            }
56        }
57        if let Some(e) = current_enum {
58            enums.push(e);
59        }
60
61        // 2. Fetch tables
62        let sql = format!(
63            "SELECT 
64                c.relname as table_name,
65                pg_catalog.obj_description(c.oid, 'pg_class') as table_comment
66             FROM pg_catalog.pg_class c
67             JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
68             WHERE n.nspname = '{}' AND c.relkind = 'r' AND c.relname NOT LIKE 'pg_%' AND c.relname NOT LIKE 'sql_%'",
69            schema
70        );
71
72        let table_rows = self.query(&sql, &[]).await?;
73        let mut tables = Vec::new();
74
75        for row in table_rows {
76            if let DsValue::List(fields) = row {
77                let table_name = match fields.first() {
78                    Some(DsValue::Text(s)) => s.clone(),
79                    _ => continue,
80                };
81                let description = match fields.get(1) {
82                    Some(DsValue::Text(s)) => Some(s.clone()),
83                    _ => None,
84                };
85
86                let mut columns = Vec::new();
87                let col_sql = format!(
88                    "SELECT 
89                        column_name, 
90                        data_type, 
91                        is_nullable,
92                        column_default,
93                        (SELECT pg_catalog.col_description(c.oid, a.attnum)
94                         FROM pg_catalog.pg_class c
95                         JOIN pg_catalog.pg_attribute a ON c.oid = a.attrelid
96                         WHERE c.relname = '{}' AND a.attname = column_name) as column_comment
97                     FROM information_schema.columns 
98                     WHERE table_schema = '{}' AND table_name = '{}'
99                     ORDER BY ordinal_position",
100                    table_name, schema, table_name
101                );
102
103                let col_rows = self.query(&col_sql, &[]).await?;
104                for col_row in col_rows {
105                    if let DsValue::List(cfields) = col_row {
106                        let name = match cfields.first() {
107                            Some(DsValue::Text(s)) => s.clone(),
108                            _ => continue,
109                        };
110                        let data_type = match cfields.get(1) {
111                            Some(DsValue::Text(s)) => s.clone(),
112                            _ => continue,
113                        };
114                        let is_nullable = match cfields.get(2) {
115                            Some(DsValue::Text(s)) => s == "YES",
116                            _ => false,
117                        };
118                        let default = match cfields.get(3) {
119                            Some(DsValue::Text(s)) => Some(s.clone()),
120                            _ => None,
121                        };
122                        let col_description = match cfields.get(4) {
123                            Some(DsValue::Text(s)) => Some(s.clone()),
124                            _ => None,
125                        };
126
127                        columns.push(ColumnInfo {
128                            name,
129                            data_type,
130                            is_nullable,
131                            is_primary_key: false, // Will fill later
132                            is_enum: false,        // Will fill later
133                            foreign_key: None,
134                            default,
135                            description: col_description,
136                        });
137                    }
138                }
139
140                tables.push(TableInfo {
141                    name: table_name,
142                    columns,
143                    description,
144                });
145            }
146        }
147
148        Ok((tables, enums))
149    }
150}
151
/// File-local shorthand for a `std::result::Result` whose error type is [`DsError`].
type Result<T> = std::result::Result<T, DsError>;
153
/// Handle for a PostgreSQL server identified by a connection URL.
///
/// Only the URL is stored; `execute` and `query` each call `connect` to open a new TCP
/// connection for the statement they run.
pub struct PostgresConnection {
    /// Connection URL. `parse_url` accepts the `postgres://` and `postgresql://` schemes
    /// and reads user, optional password, host, optional port and database from it, so
    /// this string may contain credentials.
    pub url: String,
}
157
/// Connection settings extracted from a [`PostgresConnection`] URL by `parse_url`.
struct ConnectionConfig {
    /// Server host name or address, used to build the TCP address in `connect`.
    host: String,
    /// Server TCP port; `parse_url` falls back to 5432 when the URL has none.
    port: u16,
    /// Role name sent as the `user` startup parameter.
    user: String,
    /// Password, if the URL carried one; only needed when the server asks for it.
    password: Option<String>,
    /// Database name sent as the `database` startup parameter.
    database: String,
}
165
166impl PostgresConnection {
167    pub fn new(url: String) -> Self {
168        Self { url }
169    }
170
171    fn parse_url(&self) -> Result<ConnectionConfig> {
172        let url = self
173            .url
174            .strip_prefix("postgres://")
175            .or_else(|| self.url.strip_prefix("postgresql://"))
176            .ok_or_else(|| DsError::protocol("Invalid postgres URL scheme"))?;
177
178        let (auth_host, database) = match url.find('/') {
179            Some(i) => (&url[..i], url[i + 1..].to_string()),
180            None => (url, "postgres".to_string()),
181        };
182
183        let (auth, host_port) = match auth_host.find('@') {
184            Some(i) => (&auth_host[..i], &auth_host[i + 1..]),
185            None => ("postgres", auth_host),
186        };
187
188        let (user, password) = match auth.find(':') {
189            Some(i) => (auth[..i].to_string(), Some(auth[i + 1..].to_string())),
190            None => (auth.to_string(), None),
191        };
192
193        let (host, port) = match host_port.find(':') {
194            Some(i) => (
195                host_port[..i].to_string(),
196                host_port[i + 1..]
197                    .parse()
198                    .map_err(|_| DsError::protocol("Invalid port"))?,
199            ),
200            None => (host_port.to_string(), 5432),
201        };
202
203        Ok(ConnectionConfig {
204            host,
205            port,
206            user,
207            password,
208            database,
209        })
210    }
211
212    async fn connect(&self) -> Result<Framed<TcpStream, PgCodec>> {
213        let config = self.parse_url()?;
214        let addr = format!("{}:{}", config.host, config.port);
215        let stream = TcpStream::connect(addr)
216            .await
217            .map_err(|e| DsError::protocol(format!("Failed to connect to postgres: {}", e)))?;
218
219        let mut framed = Framed::new(stream, PgCodec::new());
220
221        // 1. Startup
222        let params = vec![
223            ("user".to_string(), config.user.clone()),
224            ("database".to_string(), config.database.clone()),
225            ("client_encoding".to_string(), "UTF8".to_string()),
226        ];
227
228        framed.send(Message::Startup { params }).await?;
229
230        // 2. Authentication & Ready
231        loop {
232            let msg = framed
233                .next()
234                .await
235                .ok_or_else(|| DsError::protocol("Connection closed during startup"))??;
236
237            match msg {
238                BackendMessage::AuthenticationOk => {
239                    debug!("Postgres authentication OK");
240                }
241                BackendMessage::AuthenticationCleartextPassword => {
242                    let pass = config.password.as_ref().ok_or_else(|| {
243                        DsError::protocol("Password required by server but not provided")
244                    })?;
245                    framed.send(Message::Password(pass.clone())).await?;
246                }
247                BackendMessage::AuthenticationMD5Password { salt } => {
248                    let user = config.user.clone();
249                    let pass = config.password.as_ref().ok_or_else(|| {
250                        DsError::protocol("Password required by server but not provided")
251                    })?;
252
253                    // md5(md5(password + username) + salt)
254                    let hash1 = md5::compute(format!("{}{}", pass, user));
255                    let hash1_hex = hex::encode(hash1.0);
256
257                    let mut hash2_input = Vec::new();
258                    hash2_input.extend_from_slice(hash1_hex.as_bytes());
259                    hash2_input.extend_from_slice(&salt);
260
261                    let hash2 = md5::compute(hash2_input);
262                    let response = format!("md5{}", hex::encode(hash2.0));
263
264                    framed.send(Message::Password(response)).await?;
265                }
266                BackendMessage::ParameterStatus { name, value } => {
267                    debug!("Postgres parameter status: {} = {}", name, value);
268                }
269                BackendMessage::BackendKeyData { .. } => {}
270                BackendMessage::ReadyForQuery { .. } => {
271                    break;
272                }
273                BackendMessage::ErrorResponse { fields } => {
274                    let msg = fields
275                        .iter()
276                        .find(|(t, _)| *t == b'M')
277                        .map(|(_, m)| m.clone())
278                        .unwrap_or_default();
279                    return Err(DsError::protocol(format!(
280                        "Postgres startup error: {}",
281                        msg
282                    )));
283                }
284                _ => {
285                    debug!("Received unexpected message during startup: {:?}", msg);
286                }
287            }
288        }
289
290        Ok(framed)
291    }
292
293    pub async fn execute(&self, sql: &str, params: &[DsValue]) -> Result<u64> {
294        info!("Postgres executing: {}", sql);
295        let mut framed = self.connect().await?;
296
297        if params.is_empty() {
298            // Use Simple Query for no parameters
299            framed.send(Message::Query(sql.to_string())).await?;
300        } else {
301            // Use Extended Query for parameters
302            framed
303                .send(Message::Parse {
304                    name: "".to_string(),
305                    query: sql.to_string(),
306                    param_types: vec![0; params.len()], // Ask Postgres to infer types for all parameters
307                })
308                .await?;
309
310            framed
311                .send(Message::Bind {
312                    portal: "".to_string(),
313                    statement: "".to_string(),
314                    params: params.to_vec(),
315                })
316                .await?;
317
318            framed
319                .send(Message::Execute {
320                    portal: "".to_string(),
321                    max_rows: 0,
322                })
323                .await?;
324
325            framed.send(Message::Sync).await?;
326        }
327
328        let mut affected_rows = 0;
329        loop {
330            let msg = framed
331                .next()
332                .await
333                .ok_or_else(|| DsError::protocol("Connection closed during execute"))??;
334
335            match msg {
336                BackendMessage::CommandComplete { tag } => {
337                    if let Some(s) = tag.split_whitespace().last() {
338                        affected_rows = s.parse().unwrap_or(0);
339                    }
340                }
341                BackendMessage::ReadyForQuery { .. } => break,
342                BackendMessage::ErrorResponse { fields } => {
343                    let msg = fields
344                        .iter()
345                        .find(|(t, _)| *t == b'M')
346                        .map(|(_, m)| m.clone())
347                        .unwrap_or_default();
348                    return Err(DsError::protocol(format!(
349                        "Postgres execute error: {}",
350                        msg
351                    )));
352                }
353                _ => {}
354            }
355        }
356
357        Ok(affected_rows)
358    }
359
360    pub async fn query(&self, sql: &str, params: &[DsValue]) -> Result<Vec<DsValue>> {
361        info!("Postgres querying: {}", sql);
362
363        let mut framed = self.connect().await?;
364
365        if params.is_empty() {
366            // Use Simple Query for no parameters
367            framed.send(Message::Query(sql.to_string())).await?;
368        } else {
369            // Use Extended Query for parameters
370            framed
371                .send(Message::Parse {
372                    name: "".to_string(),
373                    query: sql.to_string(),
374                    param_types: vec![0; params.len()], // Ask Postgres to infer types for all parameters
375                })
376                .await?;
377
378            framed
379                .send(Message::Bind {
380                    portal: "".to_string(),
381                    statement: "".to_string(),
382                    params: params.to_vec(),
383                })
384                .await?;
385
386            framed
387                .send(Message::Execute {
388                    portal: "".to_string(),
389                    max_rows: 0,
390                })
391                .await?;
392
393            framed.send(Message::Sync).await?;
394        }
395
396        let mut results = Vec::new();
397        loop {
398            let msg = framed
399                .next()
400                .await
401                .ok_or_else(|| DsError::protocol("Connection closed during query"))??;
402
403            match msg {
404                BackendMessage::RowDescription { .. } => {}
405                BackendMessage::DataRow { values } => {
406                    debug!("Postgres received row with {} columns", values.len());
407                    let mut row = Vec::with_capacity(values.len());
408                    for val in values {
409                        match val {
410                            Some(bytes) => {
411                                let s = String::from_utf8_lossy(&bytes).to_string();
412                                row.push(DsValue::Text(s));
413                            }
414                            None => row.push(DsValue::Null),
415                        }
416                    }
417                    results.push(DsValue::List(row));
418                }
419                BackendMessage::ReadyForQuery { .. } => break,
420                BackendMessage::ErrorResponse { fields } => {
421                    let msg = fields
422                        .iter()
423                        .find(|(t, _)| *t == b'M')
424                        .map(|(_, m)| m.clone())
425                        .unwrap_or_default();
426                    return Err(DsError::protocol(format!("Postgres query error: {}", msg)));
427                }
428                _ => {}
429            }
430        }
431
432        Ok(results)
433    }
434
435    pub async fn begin_transaction(self) -> Result<PostgresTransaction> {
436        Ok(PostgresTransaction::new(self))
437    }
438}