pgwire_lite/
connection.rs

1// src/connection.rs
2
3use std::collections::HashMap;
4use std::ffi::{c_void, CStr};
5use std::sync::Arc;
6use std::time::Instant;
7
8use log::debug;
9
10use libpq::Connection;
11use libpq_sys::ExecStatusType::{PGRES_COMMAND_OK, PGRES_TUPLES_OK};
12use libpq_sys::{
13    PGContextVisibility, PQclear, PQconsumeInput, PQfname, PQgetResult, PQgetvalue, PQlibVersion,
14    PQnfields, PQntuples, PQresultStatus, PQresultVerboseErrorMessage, PQsendQuery,
15    PQsetErrorVerbosity, PQsetNoticeReceiver,
16};
17
18use crate::notices::{notice_receiver, Notice, NoticeStorage, Verbosity};
19use crate::value::Value;
20
21/// Main client for interacting with PostgreSQL-compatible servers.
22///
23/// This struct provides the core functionality for establishing connections
24/// and executing queries against a PostgreSQL-compatible server.
25pub struct PgwireLite {
26    hostname: String,
27    port: u16,
28    use_tls: bool,
29    verbosity: Verbosity,
30    notices: NoticeStorage,
31}
32
/// Contains the complete result of a query execution.
///
/// This struct provides access to all aspects of a query result,
/// including rows, columns, notices, and execution statistics.
/// Only the first result produced by the server is represented here.
#[derive(Debug)]
pub struct QueryResult {
    /// Rows returned by the query, represented as maps of column names to values.
    ///
    /// Each value is either `Value::String` holding the text returned by libpq
    /// or `Value::Null`. Because rows are keyed by column name, columns that
    /// share a name collapse into a single entry (the later column wins).
    pub rows: Vec<HashMap<String, Value>>,

    /// Names of the columns in the result set, in server order.
    ///
    /// A column whose name libpq cannot provide is reported as `"(unknown)"`.
    pub column_names: Vec<String>,

    /// Notices generated during query execution.
    pub notices: Vec<Notice>,

    /// Number of rows in the result set; `0` when the status is not `PGRES_TUPLES_OK`.
    pub row_count: i32,

    /// Number of columns in the result set.
    pub col_count: i32,

    /// Number of notices generated during query execution (equals `notices.len()`).
    pub notice_count: usize,

    /// Status of the query execution as reported by libpq for the first result.
    pub status: libpq_sys::ExecStatusType,

    /// Elapsed time for the query execution in milliseconds.
    ///
    /// The measurement starts before the connection is opened, so it includes
    /// connection setup as well as query execution and result processing.
    pub elapsed_time_ms: u64,
}
63
64// Helper function to safely clear a PGresult and log it
65fn clear_pg_result(result: *mut libpq_sys::PGresult) {
66    if !result.is_null() {
67        unsafe {
68            debug!("Clearing PGresult at {:p}", result);
69            PQclear(result);
70            debug!("PGresult cleared successfully");
71        }
72    }
73}
74
75impl PgwireLite {
76    /// Creates a new PgwireLite client with the specified connection parameters.
77    ///
78    /// # Arguments
79    ///
80    /// * `hostname` - The hostname or IP address of the PostgreSQL server
81    /// * `port` - The port number the PostgreSQL server is listening on
82    /// * `use_tls` - Whether to use TLS encryption for the connection
83    /// * `verbosity` - Error/notice verbosity level, one of: "terse", "default", "verbose", "sqlstate"
84    ///
85    /// # Returns
86    ///
87    /// A Result containing the new PgwireLite instance or an error
88    ///
89    /// # Example
90    ///
91    /// ```
92    /// use pgwire_lite::PgwireLite;
93    ///
94    /// let client = PgwireLite::new("localhost", 5432, false, "default")
95    ///     .expect("Failed to create client");
96    /// ```
97    pub fn new(
98        hostname: &str,
99        port: u16,
100        use_tls: bool,
101        verbosity: &str,
102    ) -> Result<Self, Box<dyn std::error::Error>> {
103        let verbosity_val = match verbosity.to_lowercase().as_str() {
104            "default" => Verbosity::Default,
105            "verbose" => Verbosity::Verbose,
106            "terse" => Verbosity::Terse,
107            "sqlstate" => Verbosity::Sqlstate,
108            "" => Verbosity::Default,
109            _ => Verbosity::Default,
110        };
111
112        // Set the log filter level based on verbosity
113        match verbosity_val {
114            Verbosity::Terse => log::set_max_level(log::LevelFilter::Warn),
115            Verbosity::Default => log::set_max_level(log::LevelFilter::Info),
116            Verbosity::Verbose => log::set_max_level(log::LevelFilter::Debug),
117            Verbosity::Sqlstate => log::set_max_level(log::LevelFilter::Debug),
118        }
119
120        let notices = Arc::new(std::sync::Mutex::new(Vec::new()));
121
122        Ok(PgwireLite {
123            hostname: hostname.to_string(),
124            port,
125            use_tls,
126            verbosity: verbosity_val,
127            notices,
128        })
129    }
130
131    /// Returns the version of the underlying libpq library.
132    ///
133    /// # Returns
134    ///
135    /// A string representing the version in the format "major.minor.patch"
136    pub fn libpq_version(&self) -> String {
137        let version = unsafe { PQlibVersion() };
138        let major = version / 10000;
139        let minor = (version / 100) % 100;
140        let patch = version % 100;
141        format!("{}.{}.{}", major, minor, patch)
142    }
143
144    /// Returns the current verbosity setting.
145    ///
146    /// # Returns
147    ///
148    /// A string representation of the current verbosity level
149    pub fn verbosity(&self) -> String {
150        format!("{:?}", self.verbosity)
151    }
152
153    // Helper method to consume any pending results
154    fn consume_pending_results(conn: &Connection) {
155        debug!("Consuming pending results");
156        unsafe {
157            // First make sure we've read all data available from the server
158            PQconsumeInput(conn.into());
159
160            // Then clear any pending results
161            loop {
162                let result = PQgetResult(conn.into());
163                if result.is_null() {
164                    break;
165                }
166                clear_pg_result(result);
167            }
168        }
169    }
170
171    /// Executes a SQL query and returns the results.
172    ///
173    /// This method creates a fresh connection for each query, executes the query,
174    /// and processes the results. It handles all aspects of connection management
175    /// and error handling.
176    ///
177    /// # Arguments
178    ///
179    /// * `query` - The SQL query to execute
180    ///
181    /// # Returns
182    ///
183    /// A Result containing a QueryResult with the query results or an error
184    ///
185    /// # Example
186    ///
187    /// ```
188    /// use pgwire_lite::PgwireLite;
189    ///
190    /// let client = PgwireLite::new("localhost", 5444, false, "default")
191    ///     .expect("Failed to create client");
192    ///     
193    /// let result = client.query("SELECT 1 as value")
194    ///     .expect("Query failed");
195    ///     
196    /// println!("Number of rows: {}", result.row_count);
197    /// ```
198    pub fn query(&self, query: &str) -> Result<QueryResult, Box<dyn std::error::Error>> {
199        // Clear any previous notices
200        debug!("Clearing previous notices");
201        if let Ok(mut notices) = self.notices.lock() {
202            notices.clear();
203        }
204
205        let start_time = Instant::now();
206
207        // Create a connection string
208        let conn_str = format!(
209            "host={} port={} sslmode={} application_name=pgwire-lite-client connect_timeout=10 client_encoding=UTF8",
210            self.hostname,
211            self.port,
212            if self.use_tls { "verify-full" } else { "disable" }
213        );
214        debug!("Establishing connection using: {}", conn_str);
215
216        // Create a fresh connection for this query
217        let conn = Connection::new(&conn_str)?;
218
219        // Connection diagnostics
220        unsafe {
221            let ssl_in_use = libpq_sys::PQsslInUse((&conn).into()) != 0;
222            let host_ptr = libpq_sys::PQhost((&conn).into());
223            let port_ptr = libpq_sys::PQport((&conn).into());
224            if !host_ptr.is_null() && !port_ptr.is_null() {
225                let host = CStr::from_ptr(host_ptr).to_string_lossy();
226                let port = CStr::from_ptr(port_ptr).to_string_lossy();
227                debug!("Connected to: {}:{} (ssl: {})", host, port, ssl_in_use);
228            }
229
230            // PQstatus output
231            let status = libpq_sys::PQstatus((&conn).into());
232            debug!("Connection status: {:?}", status);
233
234            // PQtransactionStatus output
235            let tx_status = libpq_sys::PQtransactionStatus((&conn).into());
236            debug!("Transaction status: {:?}", tx_status);
237
238            // PQserverVersion output
239            let server_version = libpq_sys::PQserverVersion((&conn).into());
240            let major = server_version / 10000;
241            let minor = (server_version / 100) % 100;
242            let revision = server_version % 100;
243            debug!(
244                "Server version: {}.{}.{} ({})",
245                major, minor, revision, server_version
246            );
247        }
248
249        // Apply the desired verbosity level
250        debug!("Setting error verbosity to: {:?}", self.verbosity);
251        unsafe {
252            PQsetErrorVerbosity((&conn).into(), self.verbosity.into());
253        }
254
255        // Set up notice receiver for the connection
256        debug!("Setting up notice receiver");
257        let notices_ptr = Arc::into_raw(self.notices.clone()) as *mut c_void;
258        unsafe {
259            PQsetNoticeReceiver((&conn).into(), Some(notice_receiver), notices_ptr);
260        }
261
262        // add ; to `query` if it doesn't end with one
263        let query = if query.ends_with(';') {
264            query.to_string()
265        } else {
266            format!("{};", query)
267        };
268
269        // Use PQsendQuery
270        debug!("Sending query: {}", query);
271        let send_success = unsafe { PQsendQuery((&conn).into(), query.as_ptr() as *const i8) };
272        if send_success == 0 {
273            // If send failed, return the error
274            return Err(
275                format!("Error: {}", conn.error_message().unwrap_or("Unknown error")).into(),
276            );
277        }
278
279        // Process the result
280        debug!("Processing the result");
281        let result = unsafe { PQgetResult((&conn).into()) };
282
283        if result.is_null() {
284            return Err("No result returned".into());
285        }
286
287        let status = unsafe { PQresultStatus(result) };
288
289        if status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK {
290            // Try to get a detailed error message
291            let error_msg_ptr = unsafe {
292                PQresultVerboseErrorMessage(
293                    result,
294                    self.verbosity.into(),
295                    PGContextVisibility::PQSHOW_CONTEXT_ALWAYS,
296                )
297            };
298
299            let error_msg = if !error_msg_ptr.is_null() {
300                // Convert the C string to a Rust string
301                let msg = unsafe { CStr::from_ptr(error_msg_ptr).to_string_lossy().into_owned() };
302                // Free the C string allocated by PQresultVerboseErrorMessage
303                unsafe { libpq_sys::PQfreemem(error_msg_ptr as *mut _) };
304                msg
305            } else {
306                // Fallback to the standard connection error message if verbose message is not available
307                conn.error_message().unwrap_or("Unknown error").to_string()
308            };
309
310            clear_pg_result(result);
311
312            // Clear any pending results
313            Self::consume_pending_results(&conn);
314
315            // return Err(format!("{}", error_msg.trim_end()).into());
316            return Err(error_msg.trim_end().to_string().into());
317        }
318
319        // Get column information
320        debug!("Getting column count");
321        let col_count = unsafe { PQnfields(result) };
322
323        // Create a vector to store column names
324        debug!("Getting column names");
325        let mut column_names = Vec::with_capacity(col_count as usize);
326        for col_index in 0..col_count {
327            let col_name_ptr = unsafe { PQfname(result, col_index) };
328            if !col_name_ptr.is_null() {
329                let col_name =
330                    unsafe { CStr::from_ptr(col_name_ptr).to_string_lossy().into_owned() };
331                column_names.push(col_name);
332            } else {
333                column_names.push(String::from("(unknown)"));
334            }
335        }
336
337        // Initialize row_count here
338        debug!("Getting row count");
339        let row_count = if status == PGRES_TUPLES_OK {
340            unsafe { PQntuples(result) }
341        } else {
342            0
343        };
344
345        // Create the rows vector
346        let mut rows = Vec::new();
347
348        // Get row data if available
349        if status == PGRES_TUPLES_OK {
350            debug!("Processing rows");
351
352            // Process each row
353            for row_index in 0..row_count {
354                let mut row_data = HashMap::new();
355
356                // Process each column in the row
357                for col_index in 0..col_count {
358                    let value_ptr = unsafe { PQgetvalue(result, row_index, col_index) };
359                    let value = if !value_ptr.is_null() {
360                        let string_value =
361                            unsafe { CStr::from_ptr(value_ptr).to_string_lossy().into_owned() };
362                        Value::String(string_value)
363                    } else {
364                        Value::Null
365                    };
366
367                    // Insert value into the row map using the column name as key
368                    row_data.insert(column_names[col_index as usize].clone(), value);
369                }
370
371                rows.push(row_data);
372            }
373        }
374        debug!("Rows processed: {}", rows.len());
375
376        clear_pg_result(result);
377
378        // Check for any remaining results and clear them
379        Self::consume_pending_results(&conn);
380
381        // Get the notices that were collected during the query
382        debug!("Collecting notices");
383        let notices = if let Ok(mut lock) = self.notices.lock() {
384            lock.drain(..).collect()
385        } else {
386            Vec::new()
387        };
388        let notice_count = notices.len();
389
390        let elapsed_time_ms = start_time.elapsed().as_millis() as u64;
391
392        drop(conn);
393
394        Ok(QueryResult {
395            rows,
396            column_names,
397            notices,
398            row_count,
399            col_count,
400            notice_count,
401            status,
402            elapsed_time_ms,
403        })
404    }
405}