serde_odbc/
param_binding.rs

/*
This file is part of serde-odbc.

serde-odbc is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

serde-odbc is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with serde-odbc.  If not, see <http://www.gnu.org/licenses/>.
*/
17use std::mem::size_of;
18use std::ptr::null;
19
20use odbc_sys::{
21    SQLSetStmtAttr, SQLHSTMT, SQLPOINTER, SQL_ATTR_PARAMSET_SIZE, SQL_ATTR_PARAM_BIND_TYPE,
22};
23use serde::ser::Serialize;
24
25use super::error::{OdbcResult, Result};
26use super::param_binder::bind_params;
27
28pub trait ParamBinding {
29    fn new() -> Self;
30
31    type Params;
32    fn params(&mut self) -> &mut Self::Params;
33
34    unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()>;
35}
36
37pub struct Params<P: Copy + Default + Serialize> {
38    data: P,
39    last_data: *const P,
40}
41
/// Parameter storage for statements that take no parameters at all.
pub struct NoParams {
    /// Unit placeholder, so `params()` has something to return a reference to.
    data: (),
}
45
46pub struct ParamSet<P: Copy + Serialize> {
47    data: Vec<P>,
48    last_data: *const P,
49    last_size: usize,
50}
51
52impl<P: Copy + Default + Serialize> ParamBinding for Params<P> {
53    fn new() -> Self {
54        Params {
55            data: Default::default(),
56            last_data: null(),
57        }
58    }
59
60    type Params = P;
61    fn params(&mut self) -> &mut Self::Params {
62        &mut self.data
63    }
64
65    unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()> {
66        let data = &self.data as *const P;
67
68        if self.last_data != data {
69            bind_params(stmt, &*data)?;
70            self.last_data = data;
71        }
72
73        Ok(())
74    }
75}
76
77impl ParamBinding for NoParams {
78    fn new() -> Self {
79        NoParams { data: () }
80    }
81
82    type Params = ();
83    fn params(&mut self) -> &mut Self::Params {
84        &mut self.data
85    }
86
87    unsafe fn bind(&mut self, _stmt: SQLHSTMT) -> Result<()> {
88        Ok(())
89    }
90}
91
92impl<P: Copy + Serialize> ParamBinding for ParamSet<P> {
93    fn new() -> Self {
94        ParamSet {
95            data: Vec::new(),
96            last_data: null(),
97            last_size: 0,
98        }
99    }
100
101    type Params = Vec<P>;
102    fn params(&mut self) -> &mut Self::Params {
103        &mut self.data
104    }
105
106    unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()> {
107        let data = self.data.first().unwrap() as *const P;
108        let size = self.data.len();
109
110        if self.last_data != data {
111            bind_params(stmt, &*data)?;
112            self.last_data = data;
113        }
114
115        if self.last_size != size {
116            Self::bind_param_set(stmt, size)?;
117            self.last_size = size;
118        }
119
120        Ok(())
121    }
122}
123
124impl<P: Copy + Serialize> ParamSet<P> {
125    unsafe fn bind_param_set(stmt: SQLHSTMT, size: usize) -> Result<()> {
126        SQLSetStmtAttr(
127            stmt,
128            SQL_ATTR_PARAM_BIND_TYPE,
129            size_of::<P>() as SQLPOINTER,
130            0,
131        )
132        .check()?;
133
134        SQLSetStmtAttr(stmt, SQL_ATTR_PARAMSET_SIZE, size as SQLPOINTER, 0).check()
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        col_binding::{Cols, NoCols},
        connection::{Connection, Environment},
        statement::Statement,
        tests::CONN_STR,
    };

    /// Inserts 128 rows through a single `ParamSet` execution and reads them back in order.
    #[test]
    fn bind_param_set() {
        let env = Environment::new().unwrap();
        let conn = Connection::new(&env, CONN_STR).unwrap();

        // Scratch table for this test.
        {
            let mut create: Statement<NoParams, NoCols> =
                Statement::new(&conn, "CREATE TEMPORARY TABLE tbl (col INTEGER NOT NULL)").unwrap();
            create.exec().unwrap();
        }

        // All rows are queued in the parameter set before a single `exec()` call.
        {
            let mut insert: Statement<ParamSet<i32>, NoCols> =
                Statement::new(&conn, "INSERT INTO tbl (col) VALUES (?)").unwrap();
            insert.params().extend(0..128);
            insert.exec().unwrap();
        }

        // Every inserted value must come back exactly once, in ascending order.
        {
            let mut select: Statement<NoParams, Cols<i32>> =
                Statement::new(&conn, "SELECT col FROM tbl ORDER BY col").unwrap();
            select.exec().unwrap();
            for expected in 0..128 {
                assert!(select.fetch().unwrap());
                assert_eq!(expected, *select.cols());
            }
            assert!(!select.fetch().unwrap());
        }
    }
}
180}