// serde_odbc/string.rs

/*
This file is part of serde-odbc.

serde-odbc is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

serde-odbc is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with serde-odbc.  If not, see <http://www.gnu.org/licenses/>.
*/
17use std::cmp::min;
18use std::mem::MaybeUninit;
19
20use generic_array::{ArrayLength, GenericArray};
21use odbc_sys::SQLLEN;
22use serde::ser::{Serialize, SerializeStruct, Serializer};
23
24use crate::binder::with_indicator;
25
26#[derive(Clone)]
27struct ByteArray<N: ArrayLength<u8>>(GenericArray<u8, N>);
28
29impl<N: Clone + ArrayLength<u8>> Copy for ByteArray<N> where N::ArrayType: Copy {}
30
31impl<N: ArrayLength<u8>> Serialize for ByteArray<N> {
32    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
33        serializer.serialize_bytes(self.0.as_slice())
34    }
35}
36
37#[derive(Clone)]
38pub struct String<N: ArrayLength<u8>> {
39    indicator: SQLLEN,
40    value: ByteArray<N>,
41}
42
43impl<N: Clone + ArrayLength<u8>> Copy for String<N> where N::ArrayType: Copy {}
44
45impl<N: ArrayLength<u8>> Serialize for String<N> {
46    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
47        let mut serializer = serializer.serialize_struct("String", 1)?;
48        with_indicator(&self.indicator as *const _ as *mut _, || {
49            serializer.serialize_field("value", &self.value)
50        })?;
51        serializer.end()
52    }
53}
54
55impl<N: ArrayLength<u8>> String<N> {
56    pub fn clear(&mut self) {
57        self.indicator = 0;
58    }
59
60    pub fn extend_from_slice(&mut self, value: &[u8]) {
61        let len = min(N::to_usize() - self.indicator as usize, value.len());
62
63        self.value.0.as_mut_slice()[self.indicator as usize..][..len]
64            .copy_from_slice(&value[..len]);
65
66        self.indicator += len as SQLLEN;
67    }
68
69    pub fn as_slice(&self) -> &[u8] {
70        &self.value.0.as_slice()[..self.indicator as usize]
71    }
72
73    pub fn as_mut_slice(&mut self) -> &mut [u8] {
74        &mut self.value.0.as_mut_slice()[..self.indicator as usize]
75    }
76}
77
78impl<N: ArrayLength<u8>> Default for String<N> {
79    fn default() -> Self {
80        Self {
81            indicator: 0,
82            value: unsafe {
83                #[allow(clippy::uninit_assumed_init)]
84                MaybeUninit::uninit().assume_init()
85            },
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    //! Unit tests for the fixed-capacity `String`, plus a round trip through an ODBC driver.
    use super::*;

    use generic_array::typenum::U8;

    use crate::{
        col_binding::Cols,
        connection::{Connection, Environment},
        param_binding::Params,
        statement::Statement,
        tests::CONN_STR,
    };

    /// Eight-byte string used by every test in this module.
    type Str8 = String<U8>;

    #[test]
    fn default_str() {
        let empty = Str8::default();
        assert_eq!(empty.as_slice(), &b""[..]);
    }

    #[test]
    fn make_str() {
        let mut text = Str8::default();
        text.extend_from_slice(&b"foobar"[..]);
        assert_eq!(text.as_slice(), &b"foobar"[..]);
    }

    #[test]
    fn bind_str() {
        let env = Environment::new().unwrap();
        let conn = Connection::new(&env, CONN_STR).unwrap();

        let mut stmt: Statement<Params<Str8>, Cols<Str8>> =
            Statement::new(&conn, "SELECT ?").unwrap();
        // Twelve bytes go in, but only the first eight fit into `Str8`.
        stmt.params().extend_from_slice(b"foobarfoobar");
        stmt.exec().unwrap();

        let got_row = stmt.fetch().unwrap();
        assert!(got_row);
        assert_eq!(stmt.cols().as_slice(), &b"foobarfo"[..]);

        let got_another = stmt.fetch().unwrap();
        assert!(!got_another);
    }
}