serde_odbc/
statement.rs

/*
This file is part of serde-odbc.

serde-odbc is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

serde-odbc is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with serde-odbc.  If not, see <http://www.gnu.org/licenses/>.
*/
17use std::ptr::null_mut;
18
19use odbc_sys::{
20    SQLAllocHandle, SQLExecute, SQLFetch, SQLFreeHandle, SQLFreeStmt, SQLPrepare, SQLHANDLE,
21    SQLHSTMT, SQLINTEGER, SQL_CLOSE, SQL_HANDLE_STMT, SQL_NO_DATA,
22};
23use serde::ser::Serialize;
24
25use super::col_binding::{ColBinding, RowSet};
26use super::connection::Connection;
27use super::error::{OdbcResult, Result};
28use super::param_binding::ParamBinding;
29
30pub struct Statement<P: ParamBinding, C: ColBinding> {
31    stmt: SQLHSTMT,
32    is_positioned: bool,
33    params: P,
34    cols: C,
35}
36
37impl<P: ParamBinding, C: ColBinding> Statement<P, C> {
38    pub fn new(conn: &Connection, stmt_str: &str) -> Result<Self> {
39        let mut stmt: SQLHANDLE = null_mut();
40
41        unsafe { SQLAllocHandle(SQL_HANDLE_STMT, conn.handle(), &mut stmt) }.check()?;
42
43        let stmt = stmt as SQLHSTMT;
44
45        unsafe { SQLPrepare(stmt, stmt_str.as_ptr(), stmt_str.len() as SQLINTEGER) }.check()?;
46
47        Ok(Statement {
48            stmt,
49            is_positioned: false,
50            params: P::new(),
51            cols: C::new(),
52        })
53    }
54
55    pub fn handle(&self) -> SQLHANDLE {
56        self.stmt as SQLHANDLE
57    }
58
59    pub fn params(&mut self) -> &mut P::Params {
60        self.params.params()
61    }
62
63    pub fn cols(&self) -> &C::Cols {
64        self.cols.cols()
65    }
66
67    pub fn exec(&mut self) -> Result<()> {
68        if self.is_positioned {
69            unsafe { SQLFreeStmt(self.stmt, SQL_CLOSE) }.check()?;
70
71            self.is_positioned = false;
72        }
73
74        unsafe {
75            self.params.bind(self.stmt)?;
76            self.cols.bind(self.stmt)?;
77        }
78
79        unsafe { SQLExecute(self.stmt) }.check()
80    }
81
82    pub fn fetch(&mut self) -> Result<bool> {
83        let rc = unsafe { SQLFetch(self.stmt) };
84
85        rc.check()?;
86
87        self.is_positioned = true;
88
89        Ok(rc != SQL_NO_DATA && self.cols.fetch())
90    }
91}
92
93impl<P: ParamBinding, C: Default + Copy + Serialize> Statement<P, RowSet<C>> {
94    pub fn with_fetch_size(conn: &Connection, stmt_str: &str, fetch_size: usize) -> Result<Self> {
95        let mut stmt = Self::new(conn, stmt_str)?;
96
97        stmt.set_fetch_size(fetch_size);
98
99        Ok(stmt)
100    }
101
102    pub fn fetch_size(&self) -> usize {
103        self.cols.fetch_size()
104    }
105
106    pub fn set_fetch_size(&mut self, size: usize) {
107        self.cols.set_fetch_size(size)
108    }
109}
110
111impl<P: ParamBinding, C: ColBinding> Drop for Statement<P, C> {
112    fn drop(&mut self) {
113        let _ = unsafe { SQLFreeHandle(SQL_HANDLE_STMT, self.handle()) };
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    use crate::col_binding::Cols;
    use crate::connection::Environment;
    use crate::param_binding::Params;
    use crate::tests::CONN_STR;

    /// Binds one integer parameter, selects it back, and checks that exactly one row
    /// carrying that value is returned.
    #[test]
    fn exec_stmt() {
        let env = Environment::new().unwrap();
        let conn = Connection::new(&env, CONN_STR).unwrap();

        let mut stmt = Statement::<Params<i32>, Cols<i32>>::new(&conn, "SELECT ?").unwrap();
        *stmt.params() = 42;
        stmt.exec().unwrap();

        let got_row = stmt.fetch().unwrap();
        assert!(got_row);
        assert_eq!(42, *stmt.cols());

        let got_another = stmt.fetch().unwrap();
        assert!(!got_another);
    }
}