serde_odbc/
param_binding.rs1use std::mem::size_of;
18use std::ptr::null;
19
20use odbc_sys::{
21 SQLSetStmtAttr, SQLHSTMT, SQLPOINTER, SQL_ATTR_PARAMSET_SIZE, SQL_ATTR_PARAM_BIND_TYPE,
22};
23use serde::ser::Serialize;
24
25use super::error::{OdbcResult, Result};
26use super::param_binder::bind_params;
27
28pub trait ParamBinding {
29 fn new() -> Self;
30
31 type Params;
32 fn params(&mut self) -> &mut Self::Params;
33
34 unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()>;
35}
36
37pub struct Params<P: Copy + Default + Serialize> {
38 data: P,
39 last_data: *const P,
40}
41
/// Parameter binding for statements that take no parameters at all.
pub struct NoParams {
    /// Unit storage, present only so `params` has something to hand out.
    data: (),
}
45
46pub struct ParamSet<P: Copy + Serialize> {
47 data: Vec<P>,
48 last_data: *const P,
49 last_size: usize,
50}
51
52impl<P: Copy + Default + Serialize> ParamBinding for Params<P> {
53 fn new() -> Self {
54 Params {
55 data: Default::default(),
56 last_data: null(),
57 }
58 }
59
60 type Params = P;
61 fn params(&mut self) -> &mut Self::Params {
62 &mut self.data
63 }
64
65 unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()> {
66 let data = &self.data as *const P;
67
68 if self.last_data != data {
69 bind_params(stmt, &*data)?;
70 self.last_data = data;
71 }
72
73 Ok(())
74 }
75}
76
77impl ParamBinding for NoParams {
78 fn new() -> Self {
79 NoParams { data: () }
80 }
81
82 type Params = ();
83 fn params(&mut self) -> &mut Self::Params {
84 &mut self.data
85 }
86
87 unsafe fn bind(&mut self, _stmt: SQLHSTMT) -> Result<()> {
88 Ok(())
89 }
90}
91
92impl<P: Copy + Serialize> ParamBinding for ParamSet<P> {
93 fn new() -> Self {
94 ParamSet {
95 data: Vec::new(),
96 last_data: null(),
97 last_size: 0,
98 }
99 }
100
101 type Params = Vec<P>;
102 fn params(&mut self) -> &mut Self::Params {
103 &mut self.data
104 }
105
106 unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()> {
107 let data = self.data.first().unwrap() as *const P;
108 let size = self.data.len();
109
110 if self.last_data != data {
111 bind_params(stmt, &*data)?;
112 self.last_data = data;
113 }
114
115 if self.last_size != size {
116 Self::bind_param_set(stmt, size)?;
117 self.last_size = size;
118 }
119
120 Ok(())
121 }
122}
123
124impl<P: Copy + Serialize> ParamSet<P> {
125 unsafe fn bind_param_set(stmt: SQLHSTMT, size: usize) -> Result<()> {
126 SQLSetStmtAttr(
127 stmt,
128 SQL_ATTR_PARAM_BIND_TYPE,
129 size_of::<P>() as SQLPOINTER,
130 0,
131 )
132 .check()?;
133
134 SQLSetStmtAttr(stmt, SQL_ATTR_PARAMSET_SIZE, size as SQLPOINTER, 0).check()
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        col_binding::{Cols, NoCols},
        connection::{Connection, Environment},
        statement::Statement,
        tests::CONN_STR,
    };

    /// Inserts 128 rows with one array-bound INSERT, then reads them back in order.
    #[test]
    fn bind_param_set() {
        let env = Environment::new().unwrap();
        let conn = Connection::new(&env, CONN_STR).unwrap();

        // Scratch table for this test; each statement lives in its own scope so it is
        // dropped before the next one is prepared.
        {
            let mut create: Statement<NoParams, NoCols> =
                Statement::new(&conn, "CREATE TEMPORARY TABLE tbl (col INTEGER NOT NULL)").unwrap();
            create.exec().unwrap();
        }

        // One execution with a 128-row parameter set.
        {
            let mut insert: Statement<ParamSet<i32>, NoCols> =
                Statement::new(&conn, "INSERT INTO tbl (col) VALUES (?)").unwrap();
            insert.params().extend(0..128);
            insert.exec().unwrap();
        }

        // Every value comes back exactly once, in ascending order, and nothing else follows.
        {
            let mut select: Statement<NoParams, Cols<i32>> =
                Statement::new(&conn, "SELECT col FROM tbl ORDER BY col").unwrap();
            select.exec().unwrap();
            for expected in 0..128 {
                assert!(select.fetch().unwrap());
                assert_eq!(expected, *select.cols());
            }
            assert!(!select.fetch().unwrap());
        }
    }
}