rsfbclient_rust/
client.rs

1//! `FirebirdConnection` implementation for the pure rust firebird client
2
3use bytes::{BufMut, Bytes, BytesMut};
4use std::{
5    env,
6    io::{Read, Write},
7    net::TcpStream,
8};
9
10use crate::{
11    arc4::*,
12    blr,
13    consts::{AuthPluginType, ProtocolVersion, WireOp},
14    srp::*,
15    util::*,
16    wire::*,
17    xsqlda::{parse_xsqlda, xsqlda_to_blr, PrepareInfo, XSqlVar, XSQLDA_DESCRIBE_VARS},
18};
19use rsfbclient_core::*;
20
21type RustDbHandle = DbHandle;
22type RustTrHandle = TrHandle;
23type RustStmtHandle = StmtHandle;
24
25/// Firebird client implemented in pure rust
26pub struct RustFbClient {
27    conn: Option<FirebirdWireConnection>,
28    charset: Charset,
29}
30
/// Required configuration for an attachment with the pure rust client
#[derive(Clone, Default)]
pub struct RustFbClientAttachmentConfig {
    /// Server host name or address, used to open the tcp connection
    pub host: String,
    /// Server tcp port
    pub port: u16,
    /// Database name sent to the server in the attach / create request
    pub db_name: String,
    /// User name used to authenticate
    pub user: String,
    /// Password used to authenticate
    pub pass: String,
    /// Optional role sent with the attach / create request
    pub role_name: Option<String>,
}
41
42/// A Connection to a firebird server
43pub struct FirebirdWireConnection {
44    /// Connection socket
45    socket: FbStream,
46
47    /// Wire protocol version
48    pub(crate) version: ProtocolVersion,
49
50    /// Buffer to read the network data
51    buff: Box<[u8]>,
52
53    /// Lazy responses to read
54    lazy_count: u32,
55
56    pub(crate) charset: Charset,
57}
58
59/// Data to keep track about a prepared statement
60pub struct StmtHandleData {
61    /// Statement handle
62    handle: RustStmtHandle,
63    /// Output xsqlda
64    xsqlda: Vec<XSqlVar>,
65    /// Blr representation of the above
66    blr: Bytes,
67    /// Number of parameters
68    param_count: usize,
69}
70
71impl RustFbClient {
72    ///Construct a new instance of the pure rust client
73    pub fn new(charset: Charset) -> Self {
74        Self {
75            conn: None,
76            charset,
77        }
78    }
79}
80
81impl FirebirdClientDbOps for RustFbClient {
82    type DbHandle = RustDbHandle;
83    type AttachmentConfig = RustFbClientAttachmentConfig;
84
85    fn attach_database(
86        &mut self,
87        config: &Self::AttachmentConfig,
88        dialect: Dialect,
89        no_db_triggers: bool,
90    ) -> Result<RustDbHandle, FbError> {
91        let host = config.host.as_str();
92        let port = config.port;
93        let db_name = config.db_name.as_str();
94        let user = config.user.as_str();
95        let pass = config.pass.as_str();
96        let role = match &config.role_name {
97            Some(ro) => Some(ro.as_str()),
98            None => None,
99        };
100
101        // Take the existing connection, or connects
102        let mut conn = match self.conn.take() {
103            Some(conn) => conn,
104            None => FirebirdWireConnection::connect(
105                host,
106                port,
107                db_name,
108                user,
109                pass,
110                self.charset.clone(),
111            )?,
112        };
113
114        let attach_result =
115            conn.attach_database(db_name, user, pass, role, dialect, no_db_triggers);
116
117        // Put the connection back
118        self.conn.replace(conn);
119
120        attach_result
121    }
122
123    fn detach_database(&mut self, db_handle: &mut RustDbHandle) -> Result<(), FbError> {
124        self.conn
125            .as_mut()
126            .map(|conn| conn.detach_database(db_handle))
127            .unwrap_or_else(err_client_not_connected)
128    }
129
130    fn drop_database(&mut self, db_handle: &mut RustDbHandle) -> Result<(), FbError> {
131        self.conn
132            .as_mut()
133            .map(|conn| conn.drop_database(db_handle))
134            .unwrap_or_else(err_client_not_connected)
135    }
136
137    fn create_database(
138        &mut self,
139        config: &Self::AttachmentConfig,
140        page_size: Option<u32>,
141        dialect: Dialect,
142    ) -> Result<RustDbHandle, FbError> {
143        let host = config.host.as_str();
144        let port = config.port;
145        let db_name = config.db_name.as_str();
146        let user = config.user.as_str();
147        let pass = config.pass.as_str();
148        let role = match &config.role_name {
149            Some(ro) => Some(ro.as_str()),
150            None => None,
151        };
152
153        // Take the existing connection, or connects
154        let mut conn = match self.conn.take() {
155            Some(conn) => conn,
156            None => FirebirdWireConnection::connect(
157                host,
158                port,
159                db_name,
160                user,
161                pass,
162                self.charset.clone(),
163            )?,
164        };
165
166        let attach_result = conn.create_database(db_name, user, pass, page_size, role, dialect);
167
168        // Put the connection back
169        self.conn.replace(conn);
170
171        attach_result
172    }
173}
174
175impl FirebirdClientSqlOps for RustFbClient {
176    type DbHandle = RustDbHandle;
177    type TrHandle = RustTrHandle;
178    type StmtHandle = StmtHandleData;
179
180    fn begin_transaction(
181        &mut self,
182        db_handle: &mut Self::DbHandle,
183        confs: TransactionConfiguration,
184    ) -> Result<Self::TrHandle, FbError> {
185        self.conn
186            .as_mut()
187            .map(|conn| conn.begin_transaction(db_handle, confs))
188            .unwrap_or_else(err_client_not_connected)
189    }
190
191    fn transaction_operation(
192        &mut self,
193        tr_handle: &mut Self::TrHandle,
194        op: TrOp,
195    ) -> Result<(), FbError> {
196        self.conn
197            .as_mut()
198            .map(|conn| conn.transaction_operation(tr_handle, op))
199            .unwrap_or_else(err_client_not_connected)
200    }
201
202    fn exec_immediate(
203        &mut self,
204        _db_handle: &mut Self::DbHandle,
205        tr_handle: &mut Self::TrHandle,
206        dialect: Dialect,
207        sql: &str,
208    ) -> Result<(), FbError> {
209        self.conn
210            .as_mut()
211            .map(|conn| conn.exec_immediate(tr_handle, dialect, sql))
212            .unwrap_or_else(err_client_not_connected)
213    }
214
215    fn prepare_statement(
216        &mut self,
217        db_handle: &mut Self::DbHandle,
218        tr_handle: &mut Self::TrHandle,
219        dialect: Dialect,
220        sql: &str,
221    ) -> Result<(StmtType, Self::StmtHandle), FbError> {
222        self.conn
223            .as_mut()
224            .map(|conn| conn.prepare_statement(db_handle, tr_handle, dialect, sql))
225            .unwrap_or_else(err_client_not_connected)
226    }
227
228    fn free_statement(
229        &mut self,
230        stmt_handle: &mut Self::StmtHandle,
231        op: FreeStmtOp,
232    ) -> Result<(), FbError> {
233        self.conn
234            .as_mut()
235            .map(|conn| conn.free_statement(stmt_handle, op))
236            .unwrap_or_else(err_client_not_connected)
237    }
238
239    fn execute(
240        &mut self,
241        _db_handle: &mut Self::DbHandle,
242        tr_handle: &mut Self::TrHandle,
243        stmt_handle: &mut Self::StmtHandle,
244        params: Vec<SqlType>,
245    ) -> Result<usize, FbError> {
246        self.conn
247            .as_mut()
248            .map(|conn| conn.execute(tr_handle, stmt_handle, &params))
249            .unwrap_or_else(err_client_not_connected)
250    }
251
252    fn execute2(
253        &mut self,
254        _db_handle: &mut Self::DbHandle,
255        tr_handle: &mut Self::TrHandle,
256        stmt_handle: &mut Self::StmtHandle,
257        params: Vec<SqlType>,
258    ) -> Result<Vec<Column>, FbError> {
259        self.conn
260            .as_mut()
261            .map(|conn| conn.execute2(tr_handle, stmt_handle, &params))
262            .unwrap_or_else(err_client_not_connected)
263    }
264
265    fn fetch(
266        &mut self,
267        _db_handle: &mut Self::DbHandle,
268        tr_handle: &mut Self::TrHandle,
269        stmt_handle: &mut Self::StmtHandle,
270    ) -> Result<Option<Vec<Column>>, FbError> {
271        self.conn
272            .as_mut()
273            .map(|conn| conn.fetch(tr_handle, stmt_handle))
274            .unwrap_or_else(err_client_not_connected)
275    }
276}
277
278fn err_client_not_connected<T>() -> Result<T, FbError> {
279    Err("Client not connected to the server, call `attach_database` to connect".into())
280}
281
282impl FirebirdWireConnection {
283    /// Start a connection to the firebird server
284    pub fn connect(
285        host: &str,
286        port: u16,
287        db_name: &str,
288        user: &str,
289        pass: &str,
290        charset: Charset,
291    ) -> Result<Self, FbError> {
292        let socket = TcpStream::connect((host, port))?;
293
294        // System username
295        let username =
296            env::var("USER").unwrap_or_else(|_| env::var("USERNAME").unwrap_or_default());
297        let hostname = socket
298            .local_addr()
299            .map(|addr| addr.to_string())
300            .unwrap_or_default();
301
302        let mut socket = FbStream::Plain(socket);
303
304        // Random key for the srp
305        let srp_key: [u8; 32] = rand::random();
306
307        let req = connect(db_name, user, &username, &hostname, &srp_key);
308        socket.write_all(&req)?;
309        socket.flush()?;
310
311        // May be a bit too much
312        let mut buff = vec![0; BUFFER_LENGTH as usize * 2].into_boxed_slice();
313
314        let len = socket.read(&mut buff)?;
315        let mut resp = Bytes::copy_from_slice(&buff[..len]);
316
317        let ConnectionResponse {
318            version,
319            auth_plugin,
320        } = parse_accept(&mut resp)?;
321
322        if let Some(mut auth_plugin) = auth_plugin {
323            loop {
324                match auth_plugin.kind {
325                    plugin @ AuthPluginType::Srp => {
326                        let srp = SrpClient::<sha1::Sha1>::new(&srp_key, &SRP_GROUP);
327
328                        if let Some(data) = auth_plugin.data {
329                            socket = srp_auth(socket, &mut buff, srp, plugin, user, pass, data)?;
330
331                            // Authentication Ok
332                            break;
333                        } else {
334                            // Server requested a different authentication method than the client specified
335                            // in the initial connection
336
337                            socket.write_all(&cont_auth(
338                                hex::encode(srp.get_a_pub()).as_bytes(),
339                                plugin,
340                                AuthPluginType::plugin_list(),
341                                &[],
342                            ))?;
343                            socket.flush()?;
344
345                            let len = socket.read(&mut buff)?;
346                            let mut resp = Bytes::copy_from_slice(&buff[..len]);
347
348                            auth_plugin = parse_cont_auth(&mut resp)?;
349                        }
350                    }
351                    plugin @ AuthPluginType::Srp256 => {
352                        let srp = SrpClient::<sha2::Sha256>::new(&srp_key, &SRP_GROUP);
353
354                        if let Some(data) = auth_plugin.data {
355                            socket = srp_auth(socket, &mut buff, srp, plugin, user, pass, data)?;
356
357                            // Authentication Ok
358                            break;
359                        } else {
360                            // Server requested a different authentication method than the client specified
361                            // in the initial connection
362
363                            socket.write_all(&cont_auth(
364                                hex::encode(srp.get_a_pub()).as_bytes(),
365                                plugin,
366                                AuthPluginType::plugin_list(),
367                                &[],
368                            ))?;
369                            socket.flush()?;
370
371                            let len = socket.read(&mut buff)?;
372                            let mut resp = Bytes::copy_from_slice(&buff[..len]);
373
374                            auth_plugin = parse_cont_auth(&mut resp)?;
375                        }
376                    }
377                }
378            }
379        }
380
381        Ok(Self {
382            socket,
383            version,
384            buff,
385            lazy_count: 0,
386            charset,
387        })
388    }
389
390    /// Create the database and attach, returning a database handle
391    pub fn create_database(
392        &mut self,
393        db_name: &str,
394        user: &str,
395        pass: &str,
396        page_size: Option<u32>,
397        role_name: Option<&str>,
398        dialect: Dialect,
399    ) -> Result<DbHandle, FbError> {
400        self.socket.write_all(&create(
401            db_name,
402            user,
403            pass,
404            self.version,
405            self.charset.clone(),
406            page_size,
407            role_name.clone(),
408            dialect,
409        ))?;
410        self.socket.flush()?;
411
412        let resp = self.read_response()?;
413
414        Ok(DbHandle(resp.handle))
415    }
416
417    /// Connect to a database, returning a database handle
418    pub fn attach_database(
419        &mut self,
420        db_name: &str,
421        user: &str,
422        pass: &str,
423        role_name: Option<&str>,
424        dialect: Dialect,
425        no_db_triggers: bool,
426    ) -> Result<DbHandle, FbError> {
427        self.socket.write_all(&attach(
428            db_name,
429            user,
430            pass,
431            self.version,
432            self.charset.clone(),
433            role_name.clone(),
434            dialect,
435            no_db_triggers,
436        ))?;
437        self.socket.flush()?;
438
439        let resp = self.read_response()?;
440
441        Ok(DbHandle(resp.handle))
442    }
443
444    /// Disconnect from the database
445    pub fn detach_database(&mut self, db_handle: &mut DbHandle) -> Result<(), FbError> {
446        self.socket.write_all(&detach(db_handle.0))?;
447        self.socket.flush()?;
448
449        self.read_response()?;
450
451        Ok(())
452    }
453
454    /// Drop the database
455    pub fn drop_database(&mut self, db_handle: &mut DbHandle) -> Result<(), FbError> {
456        self.socket.write_all(&drop_database(db_handle.0))?;
457        self.socket.flush()?;
458
459        self.read_response()?;
460
461        Ok(())
462    }
463
464    /// Start a new transaction, with the specified transaction parameter buffer
465    pub fn begin_transaction(
466        &mut self,
467        db_handle: &mut DbHandle,
468        confs: TransactionConfiguration,
469    ) -> Result<TrHandle, FbError> {
470        let mut tpb = vec![
471            ibase::isc_tpb_version3 as u8,
472            confs.isolation.into(),
473            confs.data_access as u8,
474            confs.lock_resolution.into(),
475        ];
476        if let TrLockResolution::Wait(Some(time)) = confs.lock_resolution {
477            tpb.push(ibase::isc_tpb_lock_timeout as u8);
478            tpb.push(4 as u8);
479            tpb.extend_from_slice(&time.to_le_bytes());
480        }
481
482        if let TrIsolationLevel::ReadCommited(rec) = confs.isolation {
483            tpb.push(rec as u8);
484        }
485
486        self.socket.write_all(&transaction(db_handle.0, &tpb))?;
487        self.socket.flush()?;
488
489        let resp = self.read_response()?;
490
491        Ok(TrHandle(resp.handle))
492    }
493
494    /// Commit / Rollback a transaction
495    pub fn transaction_operation(
496        &mut self,
497        tr_handle: &mut TrHandle,
498        op: TrOp,
499    ) -> Result<(), FbError> {
500        self.socket
501            .write_all(&transaction_operation(tr_handle.0, op))?;
502        self.socket.flush()?;
503
504        self.read_response()?;
505
506        Ok(())
507    }
508
509    /// Execute a sql immediately, without returning rows
510    pub fn exec_immediate(
511        &mut self,
512        tr_handle: &mut TrHandle,
513        dialect: Dialect,
514        sql: &str,
515    ) -> Result<(), FbError> {
516        self.socket.write_all(&exec_immediate(
517            tr_handle.0,
518            dialect as u32,
519            sql,
520            &self.charset,
521        )?)?;
522        self.socket.flush()?;
523
524        self.read_response()?;
525
526        Ok(())
527    }
528
529    /// Alloc and prepare a statement
530    ///
531    /// Returns the statement type, handle and xsqlda describing the columns
532    pub fn prepare_statement(
533        &mut self,
534        db_handle: &mut DbHandle,
535        tr_handle: &mut TrHandle,
536        dialect: Dialect,
537        sql: &str,
538    ) -> Result<(StmtType, StmtHandleData), FbError> {
539        // Alloc statement
540        self.socket.write_all(&allocate_statement(db_handle.0))?;
541        // Prepare statement
542        self.socket.write_all(&prepare_statement(
543            tr_handle.0,
544            u32::MAX,
545            dialect as u32,
546            sql,
547            &self.charset,
548        )?)?;
549        self.socket.flush()?;
550
551        let (mut op_code, mut resp) = self.read_packet()?;
552
553        // Read lazy responses
554        for _ in 0..self.lazy_count {
555            if op_code != WireOp::Response as u32 {
556                return err_conn_rejected(op_code);
557            }
558            self.lazy_count -= 1;
559            parse_response(&mut resp)?;
560
561            op_code = resp.get_u32()?;
562        }
563
564        // Alloc resp
565        if op_code != WireOp::Response as u32 {
566            return err_conn_rejected(op_code);
567        }
568
569        let stmt_handle = StmtHandle(parse_response(&mut resp)?.handle);
570
571        // Prepare resp
572        let op_code = resp.get_u32()?;
573
574        if op_code != WireOp::Response as u32 {
575            return err_conn_rejected(op_code);
576        }
577
578        let mut xsqlda = Vec::new();
579
580        let mut resp = parse_response(&mut resp)?;
581        let PrepareInfo {
582            stmt_type,
583            mut param_count,
584            mut truncated,
585        } = parse_xsqlda(&mut resp.data, &mut xsqlda)?;
586
587        while truncated {
588            // Get more info on the types
589            let next_index = (xsqlda.len() as u16).to_le_bytes();
590
591            self.socket.write_all(&info_sql(
592                stmt_handle.0,
593                &[
594                    &[
595                        ibase::isc_info_sql_sqlda_start as u8, // Describe a xsqlda
596                        2,
597                        next_index[0], // Index, first byte
598                        next_index[1], // Index, second byte
599                    ],
600                    &XSQLDA_DESCRIBE_VARS[..], // Data to be returned
601                ]
602                .concat(),
603            ))?;
604            self.socket.flush()?;
605
606            let mut data = self.read_response()?.data;
607
608            let parse_resp = parse_xsqlda(&mut data, &mut xsqlda)?;
609            truncated = parse_resp.truncated;
610            param_count = parse_resp.param_count;
611        }
612
613        // Coerce the output columns and transform to blr
614        for var in xsqlda.iter_mut() {
615            var.coerce()?;
616        }
617        let blr = xsqlda_to_blr(&xsqlda)?;
618
619        Ok((
620            stmt_type,
621            StmtHandleData {
622                handle: stmt_handle,
623                xsqlda,
624                blr,
625                param_count,
626            },
627        ))
628    }
629
630    /// Closes or drops a statement
631    pub fn free_statement(
632        &mut self,
633        stmt_handle: &mut StmtHandleData,
634        op: FreeStmtOp,
635    ) -> Result<(), FbError> {
636        self.socket
637            .write_all(&free_statement(stmt_handle.handle.0, op))?;
638        // Obs.: Lazy response
639
640        self.lazy_count += 1;
641
642        Ok(())
643    }
644
645    /// Execute the prepared statement with parameters
646    pub fn execute(
647        &mut self,
648        tr_handle: &mut TrHandle,
649        stmt_handle: &mut StmtHandleData,
650        params: &[SqlType],
651    ) -> Result<usize, FbError> {
652        if params.len() != stmt_handle.param_count {
653            return Err(format!(
654                "Tried to execute a statement that has {} parameters while providing {}",
655                stmt_handle.param_count,
656                params.len()
657            )
658            .into());
659        }
660
661        // Execute
662        let params = blr::params_to_blr(self, tr_handle, params)?;
663
664        self.socket.write_all(&execute(
665            tr_handle.0,
666            stmt_handle.handle.0,
667            &params.blr,
668            &params.values,
669        ))?;
670        self.socket.flush()?;
671
672        self.read_response()?;
673
674        // Get affected rows
675        self.socket.write_all(&info_sql(
676            stmt_handle.handle.0,
677            &[ibase::isc_info_sql_records as u8], // Request affected rows,
678        ))?;
679        self.socket.flush()?;
680
681        let mut data = self.read_response()?.data;
682
683        parse_info_sql_affected_rows(&mut data)
684    }
685
686    /// Execute the prepared statement with parameters, returning data
687    pub fn execute2(
688        &mut self,
689        tr_handle: &mut TrHandle,
690        stmt_handle: &mut StmtHandleData,
691        params: &[SqlType],
692    ) -> Result<Vec<Column>, FbError> {
693        if params.len() != stmt_handle.param_count {
694            return Err(format!(
695                "Tried to execute a statement that has {} parameters while providing {}",
696                stmt_handle.param_count,
697                params.len()
698            )
699            .into());
700        }
701
702        let params = blr::params_to_blr(self, tr_handle, params)?;
703
704        self.socket.write_all(&execute2(
705            tr_handle.0,
706            stmt_handle.handle.0,
707            &params.blr,
708            &params.values,
709            &stmt_handle.blr,
710        ))?;
711        self.socket.flush()?;
712
713        let (mut op_code, mut resp) = read_packet(&mut self.socket, &mut self.buff)?;
714
715        // Read lazy responses
716        for _ in 0..self.lazy_count {
717            if op_code != WireOp::Response as u32 {
718                return err_conn_rejected(op_code);
719            }
720            self.lazy_count -= 1;
721            parse_response(&mut resp)?;
722
723            op_code = resp.get_u32()?;
724        }
725
726        if op_code == WireOp::Response as u32 {
727            // An error ocurred
728            parse_response(&mut resp)?;
729        }
730
731        if op_code != WireOp::SqlResponse as u32 {
732            return err_conn_rejected(op_code);
733        }
734
735        let parsed_cols =
736            parse_sql_response(&mut resp, &stmt_handle.xsqlda, self.version, &self.charset)?;
737
738        parse_response(&mut resp)?;
739
740        let mut cols = Vec::with_capacity(parsed_cols.len());
741
742        for pc in parsed_cols {
743            cols.push(pc.into_column(self, tr_handle)?);
744        }
745
746        Ok(cols)
747    }
748
749    /// Fetch rows from the executed statement, coercing the types
750    /// according to the provided blr
751    pub fn fetch(
752        &mut self,
753        tr_handle: &mut TrHandle,
754        stmt_handle: &mut StmtHandleData,
755    ) -> Result<Option<Vec<Column>>, FbError> {
756        self.socket
757            .write_all(&fetch(stmt_handle.handle.0, &stmt_handle.blr))?;
758        self.socket.flush()?;
759
760        let (mut op_code, mut resp) = read_packet(&mut self.socket, &mut self.buff)?;
761
762        // Read lazy responses
763        for _ in 0..self.lazy_count {
764            if op_code != WireOp::Response as u32 {
765                return err_conn_rejected(op_code);
766            }
767            self.lazy_count -= 1;
768            parse_response(&mut resp)?;
769
770            op_code = resp.get_u32()?;
771        }
772
773        if op_code == WireOp::Response as u32 {
774            // An error ocurred
775            parse_response(&mut resp)?;
776        }
777
778        if op_code != WireOp::FetchResponse as u32 {
779            return err_conn_rejected(op_code);
780        }
781
782        if let Some(parsed_cols) =
783            parse_fetch_response(&mut resp, &stmt_handle.xsqlda, self.version, &self.charset)?
784        {
785            let mut cols = Vec::with_capacity(parsed_cols.len());
786
787            for pc in parsed_cols {
788                cols.push(pc.into_column(self, tr_handle)?);
789            }
790
791            Ok(Some(cols))
792        } else {
793            Ok(None)
794        }
795    }
796
797    /// Create a new blob, returning the blob handle and id
798    pub fn create_blob(
799        &mut self,
800        tr_handle: &mut TrHandle,
801    ) -> Result<(BlobHandle, BlobId), FbError> {
802        self.socket.write_all(&create_blob(tr_handle.0))?;
803        self.socket.flush()?;
804
805        let resp = self.read_response()?;
806
807        Ok((BlobHandle(resp.handle), BlobId(resp.object_id)))
808    }
809
810    /// Put blob segments
811    pub fn put_segments(&mut self, blob_handle: BlobHandle, data: &[u8]) -> Result<(), FbError> {
812        for segment in data.chunks(crate::blr::MAX_DATA_LENGTH) {
813            self.socket
814                .write_all(&put_segment(blob_handle.0, segment))?;
815            self.socket.flush()?;
816
817            self.read_response()?;
818        }
819
820        Ok(())
821    }
822
823    /// Open a blob, returning the blob handle
824    pub fn open_blob(
825        &mut self,
826        tr_handle: &mut TrHandle,
827        blob_id: BlobId,
828    ) -> Result<BlobHandle, FbError> {
829        self.socket.write_all(&open_blob(tr_handle.0, blob_id.0))?;
830        self.socket.flush()?;
831
832        let resp = self.read_response()?;
833
834        Ok(BlobHandle(resp.handle))
835    }
836
837    /// Get a blob segment, returns the bytes and true if there is more data
838    pub fn get_segment(&mut self, blob_handle: BlobHandle) -> Result<(Bytes, bool), FbError> {
839        self.socket.write_all(&get_segment(blob_handle.0))?;
840        self.socket.flush()?;
841
842        let mut blob_data = BytesMut::with_capacity(256);
843
844        let resp = self.read_response()?;
845        let mut data = resp.data;
846
847        loop {
848            if data.remaining() < 2 {
849                break;
850            }
851            let len = data.get_u16_le()? as usize;
852            if data.remaining() < len {
853                return err_invalid_response();
854            }
855            blob_data.put_slice(&data[..len]);
856            data.advance(len)?;
857        }
858
859        Ok((blob_data.freeze(), resp.handle == 2))
860    }
861
862    /// Closes a blob handle
863    pub fn close_blob(&mut self, blob_handle: BlobHandle) -> Result<(), FbError> {
864        self.socket.write_all(&close_blob(blob_handle.0))?;
865        self.socket.flush()?;
866
867        self.read_response()?;
868
869        Ok(())
870    }
871
872    /// Read a server response
873    fn read_response(&mut self) -> Result<Response, FbError> {
874        read_response(&mut self.socket, &mut self.buff, &mut self.lazy_count)
875    }
876
877    /// Reads a packet from the socket
878    fn read_packet(&mut self) -> Result<(u32, Bytes), FbError> {
879        read_packet(&mut self.socket, &mut self.buff)
880    }
881}
882
883/// Read a server response
884fn read_response(
885    socket: &mut impl Read,
886    buff: &mut [u8],
887    lazy_count: &mut u32,
888) -> Result<Response, FbError> {
889    let (mut op_code, mut resp) = read_packet(socket, buff)?;
890
891    // Read lazy responses
892    for _ in 0..*lazy_count {
893        if op_code != WireOp::Response as u32 {
894            return err_conn_rejected(op_code);
895        }
896        *lazy_count -= 1;
897        parse_response(&mut resp)?;
898
899        op_code = resp.get_u32()?;
900    }
901
902    if op_code != WireOp::Response as u32 {
903        return err_conn_rejected(op_code);
904    }
905
906    parse_response(&mut resp)
907}
908
909/// Reads a packet from the socket
910fn read_packet(socket: &mut impl Read, buff: &mut [u8]) -> Result<(u32, Bytes), FbError> {
911    let mut len = socket.read(buff)?;
912    let mut resp = BytesMut::from(&buff[..len]);
913
914    loop {
915        if len == buff.len() {
916            // The buffer was not large enough, so read more
917            len = socket.read(buff)?;
918            resp.put_slice(&buff[..len]);
919        } else {
920            break;
921        }
922    }
923    let mut resp = resp.freeze();
924
925    let op_code = loop {
926        let op_code = resp.get_u32()?;
927
928        if op_code != WireOp::Dummy as u32 {
929            break op_code;
930        }
931    };
932
933    Ok((op_code, resp))
934}
935
936/// Performs the srp authentication with the server, returning the encrypted stream
937fn srp_auth<D>(
938    mut socket: FbStream,
939    buff: &mut [u8],
940    srp: SrpClient<D>,
941    plugin: AuthPluginType,
942    user: &str,
943    pass: &str,
944    data: SrpAuthData,
945) -> Result<FbStream, FbError>
946where
947    D: digest::Digest,
948{
949    // Generate a private key with the salt received from the server
950    let private_key = srp_private_key::<sha1::Sha1>(user.as_bytes(), pass.as_bytes(), &data.salt);
951
952    // Generate a verified with the private key above and the server public key received
953    let verifier = srp
954        .process_reply(user.as_bytes(), &data.salt, &private_key, &data.pub_key)
955        .map_err(|e| FbError::from(format!("Srp error: {}", e)))?;
956
957    // Generate a proof to send to the server so it can verify the password
958    let proof = hex::encode(verifier.get_proof());
959
960    // Send proof data
961    socket.write_all(&cont_auth(
962        proof.as_bytes(),
963        plugin,
964        AuthPluginType::plugin_list(),
965        &[],
966    ))?;
967    socket.flush()?;
968
969    read_response(&mut socket, buff, &mut 0)?;
970
971    // Enable wire encryption
972    socket.write_all(&crypt("Arc4", "Symmetric"))?;
973    socket.flush()?;
974
975    socket = FbStream::Arc4(Arc4Stream::new(
976        match socket {
977            FbStream::Plain(s) => s,
978            _ => unreachable!("Stream was already encrypted!"),
979        },
980        &verifier.get_key(),
981        buff.len(),
982    ));
983
984    read_response(&mut socket, buff, &mut 0)?;
985
986    Ok(socket)
987}
988
/// A database handle
#[derive(Debug, Clone, Copy)]
pub struct DbHandle(u32);

/// A transaction handle
#[derive(Debug, Clone, Copy)]
pub struct TrHandle(u32);

/// A statement handle
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct StmtHandle(u32);

/// A blob handle
#[derive(Debug, Clone, Copy)]
pub struct BlobHandle(u32);

/// A blob Identificator
#[derive(Debug, Clone, Copy)]
pub struct BlobId(pub(crate) u64);
1008
1009/// Firebird tcp stream, may be encrypted
1010enum FbStream {
1011    /// Plaintext stream
1012    Plain(TcpStream),
1013
1014    /// Arc4 ecrypted stream
1015    Arc4(Arc4Stream<TcpStream>),
1016}
1017
1018impl Read for FbStream {
1019    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
1020        match self {
1021            FbStream::Plain(s) => s.read(buf),
1022            FbStream::Arc4(s) => s.read(buf),
1023        }
1024    }
1025}
1026
1027impl Write for FbStream {
1028    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1029        match self {
1030            FbStream::Plain(s) => s.write(buf),
1031            FbStream::Arc4(s) => s.write(buf),
1032        }
1033    }
1034
1035    fn flush(&mut self) -> std::io::Result<()> {
1036        match self {
1037            FbStream::Plain(s) => s.flush(),
1038            FbStream::Arc4(s) => s.flush(),
1039        }
1040    }
1041}
1042
#[test]
#[ignore]
/// End to end smoke test; needs a firebird server on 127.0.0.1:3050 with a
/// `test.fdb` database reachable as SYSDBA/masterkey
fn connection_test() {
    use rsfbclient_core::charset::UTF_8;

    let db_name = "test.fdb";
    let user = "SYSDBA";
    let pass = "masterkey";

    let mut conn =
        FirebirdWireConnection::connect("127.0.0.1", 3050, db_name, user, pass, UTF_8).unwrap();

    let mut db_handle = conn
        .attach_database(db_name, user, pass, None, Dialect::D3, false)
        .unwrap();

    let mut tr_handle = conn
        .begin_transaction(&mut db_handle, TransactionConfiguration::default())
        .unwrap();

    let (stmt_type, mut stmt_handle) = conn
        .prepare_statement(
            &mut db_handle,
            &mut tr_handle,
            Dialect::D3,
            "
            SELECT
                1, 'abcdefghij' as tst, rand(), CURRENT_DATE, CURRENT_TIME, CURRENT_TIMESTAMP, -1, -2, -3, -4, -5, 1, 2, 3, 4, 5, 0 as last
            FROM RDB$DATABASE where 1 = ?
            ",
        )
        .unwrap();

    println!("Statement type: {:?}", stmt_type);

    let params = match rsfbclient_core::IntoParams::to_params((1,)) {
        rsfbclient_core::ParamsType::Positional(params) => params,
        _ => unreachable!(),
    };

    conn.execute(&mut tr_handle, &mut stmt_handle, &params)
        .unwrap();

    loop {
        let resp = conn.fetch(&mut tr_handle, &mut stmt_handle).unwrap();

        if resp.is_none() {
            break;
        }
        println!("Fetch Resp: {:#?}", resp);
    }

    // Release the server resources used by the test; the lazy drop response
    // is consumed by the rollback's read
    conn.free_statement(&mut stmt_handle, FreeStmtOp::Drop)
        .unwrap();
    conn.transaction_operation(&mut tr_handle, TrOp::Rollback)
        .unwrap();
    conn.detach_database(&mut db_handle).unwrap();
}