Skip to main content

rsfbclient_rust/
client.rs

1//! `FirebirdConnection` implementation for the pure rust firebird client
2
3use bytes::{BufMut, Bytes, BytesMut};
4use std::{
5    collections::VecDeque,
6    env,
7    io::{Read, Write},
8    net::TcpStream,
9};
10
11use crate::{
12    arc4::*,
13    blr,
14    consts::{AuthPluginType, ProtocolVersion, WireOp},
15    srp::*,
16    util::*,
17    wire::*,
18    xsqlda::{parse_xsqlda, xsqlda_to_blr, PrepareInfo, XSqlVar, XSQLDA_DESCRIBE_VARS},
19};
20use rsfbclient_core::*;
21
/// Database handle type used by the pure rust client.
type RustDbHandle = DbHandle;
/// Transaction handle type used by the pure rust client.
type RustTrHandle = TrHandle;
/// Raw statement handle type used by the pure rust client.
type RustStmtHandle = StmtHandle;
25
/// Number of rows to request per `op_fetch` round-trip.
///
/// Read from the `FB_FETCH_BATCH` environment variable; surrounding
/// whitespace in the value is ignored. Falls back to 200 when the variable
/// is unset, is not a valid `u32`, or is zero.
fn fetch_batch_size() -> u32 {
    const DEFAULT_FETCH_BATCH: u32 = 200;

    env::var("FB_FETCH_BATCH")
        .ok()
        // Trim first: a stray space or newline must not silently discard
        // an otherwise valid setting.
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&rows| rows > 0)
        .unwrap_or(DEFAULT_FETCH_BATCH)
}
35
/// Outcome of parsing a single `op_fetch_response` packet.
enum FetchOne {
    /// The packet carried one row (status=0, messages=1).
    Row(Vec<Column>),
    /// End of THIS batch (status=0, messages=0): the server finished the
    /// current op_fetch without exhausting the cursor. Issuing another
    /// op_fetch retrieves the remaining rows.
    BatchEnd,
    /// End of cursor (status=100). There is nothing more to read.
    End,
}
46
/// Failure modes when parsing a single `op_fetch_response` packet.
enum FetchErr {
    /// Not enough bytes in the buffer — read more from the socket and retry.
    NeedMore,
    /// A real protocol/server error — propagate it to the caller.
    Fatal(FbError),
}
53
/// Firebird client implemented in pure rust
pub struct RustFbClient {
    /// Wire connection; `None` until the first attach/create opens it.
    conn: Option<FirebirdWireConnection>,
    /// Charset handed to the wire connection when it is opened.
    charset: Charset,
}
59
/// Required configuration for an attachment with the pure rust client
#[derive(Default, Clone)]
pub struct RustFbClientAttachmentConfig {
    /// Server host to open the TCP connection to
    pub host: String,
    /// Server TCP port
    pub port: u16,
    /// Database name sent in the connect / attach / create requests
    pub db_name: String,
    /// User name used for authentication
    pub user: String,
    /// Password used for authentication
    pub pass: String,
    /// Optional role name sent with the attach / create request
    pub role_name: Option<String>,
}
70
/// A Connection to a firebird server
pub struct FirebirdWireConnection {
    /// Connection socket
    socket: FbStream,

    /// Wire protocol version
    pub(crate) version: ProtocolVersion,

    /// Buffer to read the network data
    buff: Box<[u8]>,

    /// Number of lazy responses still to be read from the server.
    /// Incremented when a request is sent without waiting for its reply
    /// (see `free_statement`) and decremented as those replies are consumed.
    lazy_count: u32,

    /// Charset passed to the request builders and response parsers
    pub(crate) charset: Charset,
}
87
/// Data to keep track about a prepared statement
pub struct StmtHandleData {
    /// Statement handle
    handle: RustStmtHandle,
    /// Output xsqlda
    xsqlda: Vec<XSqlVar>,
    /// Blr representation of the above
    blr: Bytes,
    /// Number of parameters
    param_count: usize,
    /// Rows already received in a batch fetch but not yet handed to the caller.
    prefetched: VecDeque<Vec<Column>>,
    /// Set once the server reports the cursor as exhausted; no further
    /// batches are requested while this is `true`.
    cursor_eof: bool,
}
103
104impl RustFbClient {
105    ///Construct a new instance of the pure rust client
106    pub fn new(charset: Charset) -> Self {
107        Self {
108            conn: None,
109            charset,
110        }
111    }
112}
113
114impl FirebirdClientDbOps for RustFbClient {
115    type DbHandle = RustDbHandle;
116    type AttachmentConfig = RustFbClientAttachmentConfig;
117
118    fn attach_database(
119        &mut self,
120        config: &Self::AttachmentConfig,
121        dialect: Dialect,
122        no_db_triggers: bool,
123    ) -> Result<RustDbHandle, FbError> {
124        let host = config.host.as_str();
125        let port = config.port;
126        let db_name = config.db_name.as_str();
127        let user = config.user.as_str();
128        let pass = config.pass.as_str();
129        let role = match &config.role_name {
130            Some(ro) => Some(ro.as_str()),
131            None => None,
132        };
133
134        // Take the existing connection, or connects
135        let mut conn = match self.conn.take() {
136            Some(conn) => conn,
137            None => FirebirdWireConnection::connect(
138                host,
139                port,
140                db_name,
141                user,
142                pass,
143                self.charset.clone(),
144            )?,
145        };
146
147        let attach_result =
148            conn.attach_database(db_name, user, pass, role, dialect, no_db_triggers);
149
150        // Put the connection back
151        self.conn.replace(conn);
152
153        attach_result
154    }
155
156    fn detach_database(&mut self, db_handle: &mut RustDbHandle) -> Result<(), FbError> {
157        self.conn
158            .as_mut()
159            .map(|conn| conn.detach_database(db_handle))
160            .unwrap_or_else(err_client_not_connected)
161    }
162
163    fn drop_database(&mut self, db_handle: &mut RustDbHandle) -> Result<(), FbError> {
164        self.conn
165            .as_mut()
166            .map(|conn| conn.drop_database(db_handle))
167            .unwrap_or_else(err_client_not_connected)
168    }
169
170    fn create_database(
171        &mut self,
172        config: &Self::AttachmentConfig,
173        page_size: Option<u32>,
174        dialect: Dialect,
175    ) -> Result<RustDbHandle, FbError> {
176        let host = config.host.as_str();
177        let port = config.port;
178        let db_name = config.db_name.as_str();
179        let user = config.user.as_str();
180        let pass = config.pass.as_str();
181        let role = match &config.role_name {
182            Some(ro) => Some(ro.as_str()),
183            None => None,
184        };
185
186        // Take the existing connection, or connects
187        let mut conn = match self.conn.take() {
188            Some(conn) => conn,
189            None => FirebirdWireConnection::connect(
190                host,
191                port,
192                db_name,
193                user,
194                pass,
195                self.charset.clone(),
196            )?,
197        };
198
199        let attach_result = conn.create_database(db_name, user, pass, page_size, role, dialect);
200
201        // Put the connection back
202        self.conn.replace(conn);
203
204        attach_result
205    }
206}
207
208impl FirebirdClientSqlOps for RustFbClient {
209    type DbHandle = RustDbHandle;
210    type TrHandle = RustTrHandle;
211    type StmtHandle = StmtHandleData;
212
213    fn begin_transaction(
214        &mut self,
215        db_handle: &mut Self::DbHandle,
216        confs: TransactionConfiguration,
217    ) -> Result<Self::TrHandle, FbError> {
218        self.conn
219            .as_mut()
220            .map(|conn| conn.begin_transaction(db_handle, confs))
221            .unwrap_or_else(err_client_not_connected)
222    }
223
224    fn transaction_operation(
225        &mut self,
226        tr_handle: &mut Self::TrHandle,
227        op: TrOp,
228    ) -> Result<(), FbError> {
229        self.conn
230            .as_mut()
231            .map(|conn| conn.transaction_operation(tr_handle, op))
232            .unwrap_or_else(err_client_not_connected)
233    }
234
235    fn exec_immediate(
236        &mut self,
237        _db_handle: &mut Self::DbHandle,
238        tr_handle: &mut Self::TrHandle,
239        dialect: Dialect,
240        sql: &str,
241    ) -> Result<(), FbError> {
242        self.conn
243            .as_mut()
244            .map(|conn| conn.exec_immediate(tr_handle, dialect, sql))
245            .unwrap_or_else(err_client_not_connected)
246    }
247
248    fn prepare_statement(
249        &mut self,
250        db_handle: &mut Self::DbHandle,
251        tr_handle: &mut Self::TrHandle,
252        dialect: Dialect,
253        sql: &str,
254    ) -> Result<(StmtType, Self::StmtHandle), FbError> {
255        self.conn
256            .as_mut()
257            .map(|conn| conn.prepare_statement(db_handle, tr_handle, dialect, sql))
258            .unwrap_or_else(err_client_not_connected)
259    }
260
261    fn free_statement(
262        &mut self,
263        stmt_handle: &mut Self::StmtHandle,
264        op: FreeStmtOp,
265    ) -> Result<(), FbError> {
266        self.conn
267            .as_mut()
268            .map(|conn| conn.free_statement(stmt_handle, op))
269            .unwrap_or_else(err_client_not_connected)
270    }
271
272    fn execute(
273        &mut self,
274        _db_handle: &mut Self::DbHandle,
275        tr_handle: &mut Self::TrHandle,
276        stmt_handle: &mut Self::StmtHandle,
277        params: Vec<SqlType>,
278    ) -> Result<usize, FbError> {
279        self.conn
280            .as_mut()
281            .map(|conn| conn.execute(tr_handle, stmt_handle, &params))
282            .unwrap_or_else(err_client_not_connected)
283    }
284
285    fn execute2(
286        &mut self,
287        _db_handle: &mut Self::DbHandle,
288        tr_handle: &mut Self::TrHandle,
289        stmt_handle: &mut Self::StmtHandle,
290        params: Vec<SqlType>,
291    ) -> Result<Vec<Column>, FbError> {
292        self.conn
293            .as_mut()
294            .map(|conn| conn.execute2(tr_handle, stmt_handle, &params))
295            .unwrap_or_else(err_client_not_connected)
296    }
297
298    fn fetch(
299        &mut self,
300        _db_handle: &mut Self::DbHandle,
301        tr_handle: &mut Self::TrHandle,
302        stmt_handle: &mut Self::StmtHandle,
303    ) -> Result<Option<Vec<Column>>, FbError> {
304        self.conn
305            .as_mut()
306            .map(|conn| conn.fetch(tr_handle, stmt_handle))
307            .unwrap_or_else(err_client_not_connected)
308    }
309}
310
311fn err_client_not_connected<T>() -> Result<T, FbError> {
312    Err("Client not connected to the server, call `attach_database` to connect".into())
313}
314
315impl FirebirdWireConnection {
    /// Start a connection to the firebird server
    ///
    /// Opens a TCP connection to `host:port`, sends the initial connect
    /// packet and, when the accept response names an authentication plugin
    /// (Srp or Srp256), runs the authentication exchange before returning.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be opened, written or read, if a server
    /// response cannot be parsed, or if the authentication step fails.
    pub fn connect(
        host: &str,
        port: u16,
        db_name: &str,
        user: &str,
        pass: &str,
        charset: Charset,
    ) -> Result<Self, FbError> {
        let socket = TcpStream::connect((host, port))?;

        // System username: `USER`, then `USERNAME`, then empty
        let username =
            env::var("USER").unwrap_or_else(|_| env::var("USERNAME").unwrap_or_default());
        // The local socket address is sent as the client "hostname"
        let hostname = socket
            .local_addr()
            .map(|addr| addr.to_string())
            .unwrap_or_default();

        let mut socket = FbStream::Plain(socket);

        // Random key for the srp
        let srp_key: [u8; 32] = rand::random();

        let req = connect(db_name, user, &username, &hostname, &srp_key);
        socket.write_all(&req)?;
        socket.flush()?;

        // May be a bit too much
        let mut buff = vec![0; BUFFER_LENGTH as usize * 2].into_boxed_slice();

        // A single read is issued; the parser only sees what arrived in it
        let len = socket.read(&mut buff)?;
        let mut resp = Bytes::copy_from_slice(&buff[..len]);

        let ConnectionResponse {
            version,
            auth_plugin,
        } = parse_accept(&mut resp)?;

        // No plugin in the accept response means no authentication exchange here
        if let Some(mut auth_plugin) = auth_plugin {
            // Loops until an authentication succeeds; each pass either
            // authenticates with the plugin data at hand or asks the server
            // to continue with the plugin it selected.
            loop {
                match auth_plugin.kind {
                    plugin @ AuthPluginType::Srp => {
                        let srp = SrpClient::<sha1::Sha1>::new(&srp_key, &SRP_GROUP);

                        if let Some(data) = auth_plugin.data {
                            // `srp_auth` hands back the stream to keep using
                            // (presumably wrapped with wire encryption — confirm in `srp_auth`)
                            socket = srp_auth(socket, &mut buff, srp, plugin, user, pass, data)?;

                            // Authentication Ok
                            break;
                        } else {
                            // Server requested a different authentication method than the client specified
                            // in the initial connection

                            socket.write_all(&cont_auth(
                                hex::encode(srp.get_a_pub()).as_bytes(),
                                plugin,
                                AuthPluginType::plugin_list(),
                                &[],
                            ))?;
                            socket.flush()?;

                            let len = socket.read(&mut buff)?;
                            let mut resp = Bytes::copy_from_slice(&buff[..len]);

                            // Retry the loop with the plugin/data the server answered with
                            auth_plugin = parse_cont_auth(&mut resp)?;
                        }
                    }
                    plugin @ AuthPluginType::Srp256 => {
                        let srp = SrpClient::<sha2::Sha256>::new(&srp_key, &SRP_GROUP);

                        if let Some(data) = auth_plugin.data {
                            // `srp_auth` hands back the stream to keep using
                            // (presumably wrapped with wire encryption — confirm in `srp_auth`)
                            socket = srp_auth(socket, &mut buff, srp, plugin, user, pass, data)?;

                            // Authentication Ok
                            break;
                        } else {
                            // Server requested a different authentication method than the client specified
                            // in the initial connection

                            socket.write_all(&cont_auth(
                                hex::encode(srp.get_a_pub()).as_bytes(),
                                plugin,
                                AuthPluginType::plugin_list(),
                                &[],
                            ))?;
                            socket.flush()?;

                            let len = socket.read(&mut buff)?;
                            let mut resp = Bytes::copy_from_slice(&buff[..len]);

                            // Retry the loop with the plugin/data the server answered with
                            auth_plugin = parse_cont_auth(&mut resp)?;
                        }
                    }
                }
            }
        }

        Ok(Self {
            socket,
            version,
            buff,
            lazy_count: 0,
            charset,
        })
    }
422
423    /// Create the database and attach, returning a database handle
424    pub fn create_database(
425        &mut self,
426        db_name: &str,
427        user: &str,
428        pass: &str,
429        page_size: Option<u32>,
430        role_name: Option<&str>,
431        dialect: Dialect,
432    ) -> Result<DbHandle, FbError> {
433        self.socket.write_all(&create(
434            db_name,
435            user,
436            pass,
437            self.version,
438            self.charset.clone(),
439            page_size,
440            role_name.clone(),
441            dialect,
442        ))?;
443        self.socket.flush()?;
444
445        let resp = self.read_response()?;
446
447        Ok(DbHandle(resp.handle))
448    }
449
450    /// Connect to a database, returning a database handle
451    pub fn attach_database(
452        &mut self,
453        db_name: &str,
454        user: &str,
455        pass: &str,
456        role_name: Option<&str>,
457        dialect: Dialect,
458        no_db_triggers: bool,
459    ) -> Result<DbHandle, FbError> {
460        self.socket.write_all(&attach(
461            db_name,
462            user,
463            pass,
464            self.version,
465            self.charset.clone(),
466            role_name.clone(),
467            dialect,
468            no_db_triggers,
469        ))?;
470        self.socket.flush()?;
471
472        let resp = self.read_response()?;
473
474        Ok(DbHandle(resp.handle))
475    }
476
477    /// Disconnect from the database
478    pub fn detach_database(&mut self, db_handle: &mut DbHandle) -> Result<(), FbError> {
479        self.socket.write_all(&detach(db_handle.0))?;
480        self.socket.flush()?;
481
482        self.read_response()?;
483
484        Ok(())
485    }
486
487    /// Drop the database
488    pub fn drop_database(&mut self, db_handle: &mut DbHandle) -> Result<(), FbError> {
489        self.socket.write_all(&drop_database(db_handle.0))?;
490        self.socket.flush()?;
491
492        self.read_response()?;
493
494        Ok(())
495    }
496
497    /// Start a new transaction, with the specified transaction parameter buffer
498    pub fn begin_transaction(
499        &mut self,
500        db_handle: &mut DbHandle,
501        confs: TransactionConfiguration,
502    ) -> Result<TrHandle, FbError> {
503        let mut tpb = vec![
504            ibase::isc_tpb_version3 as u8,
505            confs.isolation.into(),
506            confs.data_access as u8,
507            confs.lock_resolution.into(),
508        ];
509        if let TrLockResolution::Wait(Some(time)) = confs.lock_resolution {
510            tpb.push(ibase::isc_tpb_lock_timeout as u8);
511            tpb.push(4 as u8);
512            tpb.extend_from_slice(&time.to_le_bytes());
513        }
514
515        if let TrIsolationLevel::ReadCommited(rec) = confs.isolation {
516            tpb.push(rec as u8);
517        }
518
519        self.socket.write_all(&transaction(db_handle.0, &tpb))?;
520        self.socket.flush()?;
521
522        let resp = self.read_response()?;
523
524        Ok(TrHandle(resp.handle))
525    }
526
527    /// Commit / Rollback a transaction
528    pub fn transaction_operation(
529        &mut self,
530        tr_handle: &mut TrHandle,
531        op: TrOp,
532    ) -> Result<(), FbError> {
533        self.socket
534            .write_all(&transaction_operation(tr_handle.0, op))?;
535        self.socket.flush()?;
536
537        self.read_response()?;
538
539        Ok(())
540    }
541
542    /// Execute a sql immediately, without returning rows
543    pub fn exec_immediate(
544        &mut self,
545        tr_handle: &mut TrHandle,
546        dialect: Dialect,
547        sql: &str,
548    ) -> Result<(), FbError> {
549        self.socket.write_all(&exec_immediate(
550            tr_handle.0,
551            dialect as u32,
552            sql,
553            &self.charset,
554        )?)?;
555        self.socket.flush()?;
556
557        self.read_response()?;
558
559        Ok(())
560    }
561
    /// Alloc and prepare a statement
    ///
    /// Returns the statement type, handle and xsqlda describing the columns
    ///
    /// The allocate and prepare requests are written back to back and
    /// flushed once, so the reply is parsed as a sequence of responses:
    /// any pending lazy responses first, then the allocate response, then
    /// the prepare response.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors, when an unexpected op code is received, or when
    /// any response or column description fails to parse.
    pub fn prepare_statement(
        &mut self,
        db_handle: &mut DbHandle,
        tr_handle: &mut TrHandle,
        dialect: Dialect,
        sql: &str,
    ) -> Result<(StmtType, StmtHandleData), FbError> {
        // Alloc statement
        self.socket.write_all(&allocate_statement(db_handle.0))?;
        // Prepare statement
        // NOTE(review): `u32::MAX` is passed where a statement handle would go —
        // presumably a placeholder for "the statement allocated just before"; confirm.
        self.socket.write_all(&prepare_statement(
            tr_handle.0,
            u32::MAX,
            dialect as u32,
            sql,
            &self.charset,
        )?)?;
        self.socket.flush()?;

        let (mut op_code, mut resp) = self.read_packet()?;

        // Read lazy responses
        // The range is evaluated once, so decrementing `lazy_count` inside
        // the loop does not shorten it.
        for _ in 0..self.lazy_count {
            if op_code != WireOp::Response as u32 {
                return err_conn_rejected(op_code);
            }
            self.lazy_count -= 1;
            parse_response(&mut resp)?;

            // Op code of the next response in the same buffer
            op_code = resp.get_u32()?;
        }

        // Alloc resp
        if op_code != WireOp::Response as u32 {
            return err_conn_rejected(op_code);
        }

        let stmt_handle = StmtHandle(parse_response(&mut resp)?.handle);

        // Prepare resp
        let op_code = resp.get_u32()?;

        if op_code != WireOp::Response as u32 {
            return err_conn_rejected(op_code);
        }

        let mut xsqlda = Vec::new();

        let mut resp = parse_response(&mut resp)?;
        let PrepareInfo {
            stmt_type,
            mut param_count,
            mut truncated,
        } = parse_xsqlda(&mut resp.data, &mut xsqlda)?;

        // Keep asking for column descriptions until the parser no longer
        // reports the data as truncated
        while truncated {
            // Get more info on the types
            // Resume from the first column that has not been described yet
            let next_index = (xsqlda.len() as u16).to_le_bytes();

            self.socket.write_all(&info_sql(
                stmt_handle.0,
                &[
                    &[
                        ibase::isc_info_sql_sqlda_start as u8, // Describe a xsqlda
                        2,
                        next_index[0], // Index, first byte
                        next_index[1], // Index, second byte
                    ],
                    &XSQLDA_DESCRIBE_VARS[..], // Data to be returned
                ]
                .concat(),
            ))?;
            self.socket.flush()?;

            let mut data = self.read_response()?.data;

            let parse_resp = parse_xsqlda(&mut data, &mut xsqlda)?;
            truncated = parse_resp.truncated;
            param_count = parse_resp.param_count;
        }

        // Coerce the output columns and transform to blr
        for var in xsqlda.iter_mut() {
            var.coerce()?;
        }
        let blr = xsqlda_to_blr(&xsqlda)?;

        // A fresh statement starts with an empty row buffer and an open cursor state
        Ok((
            stmt_type,
            StmtHandleData {
                handle: stmt_handle,
                xsqlda,
                blr,
                param_count,
                prefetched: VecDeque::new(),
                cursor_eof: false,
            },
        ))
    }
664
665    /// Closes or drops a statement
666    pub fn free_statement(
667        &mut self,
668        stmt_handle: &mut StmtHandleData,
669        op: FreeStmtOp,
670    ) -> Result<(), FbError> {
671        self.socket
672            .write_all(&free_statement(stmt_handle.handle.0, op))?;
673        // Obs.: Lazy response
674
675        self.lazy_count += 1;
676
677        Ok(())
678    }
679
680    /// Execute the prepared statement with parameters
681    pub fn execute(
682        &mut self,
683        tr_handle: &mut TrHandle,
684        stmt_handle: &mut StmtHandleData,
685        params: &[SqlType],
686    ) -> Result<usize, FbError> {
687        if params.len() != stmt_handle.param_count {
688            return Err(format!(
689                "Tried to execute a statement that has {} parameters while providing {}",
690                stmt_handle.param_count,
691                params.len()
692            )
693            .into());
694        }
695
696        // Reopen the cursor: drop prefetched rows and the batch-fetch EOF flag
697        // from the previous execution. Without this, re-executing the same
698        // statement would inherit cursor_eof=true and fetch nothing.
699        stmt_handle.prefetched.clear();
700        stmt_handle.cursor_eof = false;
701
702        // Execute
703        let params = blr::params_to_blr(self, tr_handle, params)?;
704
705        self.socket.write_all(&execute(
706            tr_handle.0,
707            stmt_handle.handle.0,
708            &params.blr,
709            &params.values,
710        ))?;
711        self.socket.flush()?;
712
713        self.read_response()?;
714
715        // Get affected rows
716        self.socket.write_all(&info_sql(
717            stmt_handle.handle.0,
718            &[ibase::isc_info_sql_records as u8], // Request affected rows,
719        ))?;
720        self.socket.flush()?;
721
722        let mut data = self.read_response()?.data;
723
724        parse_info_sql_affected_rows(&mut data)
725    }
726
    /// Execute the prepared statement with parameters, returning data
    ///
    /// Sends the statement together with the output blr and parses the
    /// single row carried by the sql response into columns.
    ///
    /// # Errors
    ///
    /// Fails when the number of parameters does not match the statement, on
    /// I/O errors, when an unexpected op code is received, or when a
    /// response or column fails to parse.
    pub fn execute2(
        &mut self,
        tr_handle: &mut TrHandle,
        stmt_handle: &mut StmtHandleData,
        params: &[SqlType],
    ) -> Result<Vec<Column>, FbError> {
        if params.len() != stmt_handle.param_count {
            return Err(format!(
                "Tried to execute a statement that has {} parameters while providing {}",
                stmt_handle.param_count,
                params.len()
            )
            .into());
        }

        // Reopen the cursor (same reason as execute): reset the batch-fetch
        // state from the previous execution.
        stmt_handle.prefetched.clear();
        stmt_handle.cursor_eof = false;

        let params = blr::params_to_blr(self, tr_handle, params)?;

        self.socket.write_all(&execute2(
            tr_handle.0,
            stmt_handle.handle.0,
            &params.blr,
            &params.values,
            &stmt_handle.blr,
        ))?;
        self.socket.flush()?;

        let (mut op_code, mut resp) = read_packet(&mut self.socket, &mut self.buff)?;

        // Read lazy responses
        // The range is evaluated once, so decrementing `lazy_count` inside
        // the loop does not shorten it.
        for _ in 0..self.lazy_count {
            if op_code != WireOp::Response as u32 {
                return err_conn_rejected(op_code);
            }
            self.lazy_count -= 1;
            parse_response(&mut resp)?;

            // Op code of the next response in the same buffer
            op_code = resp.get_u32()?;
        }

        if op_code == WireOp::Response as u32 {
            // An error occurred
            // A plain response here is expected to carry the error, which `?`
            // propagates; if it parses cleanly the check below rejects it.
            parse_response(&mut resp)?;
        }

        if op_code != WireOp::SqlResponse as u32 {
            return err_conn_rejected(op_code);
        }

        let parsed_cols =
            parse_sql_response(&mut resp, &stmt_handle.xsqlda, self.version, &self.charset)?;

        // The sql response is followed by a regular response in the same buffer
        parse_response(&mut resp)?;

        let mut cols = Vec::with_capacity(parsed_cols.len());

        // Converting a parsed column takes the connection and the transaction
        for pc in parsed_cols {
            cols.push(pc.into_column(self, tr_handle)?);
        }

        Ok(cols)
    }
794
795    /// Fetch ONE row. Served from a buffer filled in batches: when the buffer
796    /// empties, a single op_fetch requests `FB_FETCH_BATCH` rows in one
797    /// round-trip (it used to be one row per round-trip). Streaming is
798    /// preserved — rows come out one at a time, memory bounded to one batch.
799    pub fn fetch(
800        &mut self,
801        tr_handle: &mut TrHandle,
802        stmt_handle: &mut StmtHandleData,
803    ) -> Result<Option<Vec<Column>>, FbError> {
804        let count = fetch_batch_size();
805        let mut empty_batches = 0u32;
806        while stmt_handle.prefetched.is_empty() && !stmt_handle.cursor_eof {
807            self.fetch_batch(tr_handle, stmt_handle, count)?;
808            empty_batches += 1;
809            // Safety net: a well-behaved server never sends empty batches
810            // without exhausting the cursor; guards against a hang if it does.
811            if empty_batches > 1000 {
812                return Err("fetch: too many empty batches without end of cursor".into());
813            }
814        }
815        Ok(stmt_handle.prefetched.pop_front())
816    }
817
    /// Requests `count` rows in one op_fetch and reads every op_fetch_response
    /// that arrives, filling `stmt_handle.prefetched`. Parses greedily; if bytes
    /// are missing mid-response, reads more from the socket and resumes (robust
    /// against responses split across TCP segments — the cause of the old
    /// "Invalid server response, missing bytes").
    ///
    /// Returns once the server ends the batch or the cursor; in the latter
    /// case `stmt_handle.cursor_eof` is set.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors, when the connection is closed mid-batch, when the
    /// server sends more rows than requested, or on any fatal parse error.
    ///
    /// NOTE(review): each row is converted with `into_column`, which receives
    /// the connection. If that conversion performs wire round-trips (not
    /// visible from here), they would interleave with the still-unread fetch
    /// responses of this batch — confirm this cannot happen.
    fn fetch_batch(
        &mut self,
        tr_handle: &mut TrHandle,
        stmt_handle: &mut StmtHandleData,
        count: u32,
    ) -> Result<(), FbError> {
        self.socket
            .write_all(&fetch(stmt_handle.handle.0, &stmt_handle.blr, count))?;
        self.socket.flush()?;

        // Bytes received but not yet parsed (empty before the first read)
        let mut acc = BytesMut::new();
        // Rows received in this batch
        let mut got = 0u32;

        loop {
            // Parse as many responses as the accumulated bytes allow.
            let mut view = std::mem::take(&mut acc).freeze();
            loop {
                let snapshot = view.clone(); // O(1): Bytes shares the underlying buffer
                let saved_lazy = self.lazy_count;
                match self.parse_one_fetch_response(&mut view, &stmt_handle.xsqlda, tr_handle) {
                    Ok(FetchOne::Row(cols)) => {
                        stmt_handle.prefetched.push_back(cols);
                        got += 1;
                        // Do NOT return on got>=count: after the rows, the server
                        // always sends a terminating op_fetch_response (messages=0
                        // = end of this batch, or status=100 = end of cursor).
                        // Returning early would discard that terminator (still
                        // buffered) and desync the next op_fetch. Let BatchEnd/End
                        // end the loop. Guard against a server sending more rows
                        // than requested (should not happen).
                        if got > count {
                            return Err("server sent more rows than requested in op_fetch".into());
                        }
                    }
                    Ok(FetchOne::BatchEnd) => {
                        // Server ended this op_fetch without exhausting the cursor.
                        // Deliver what arrived; the next fetch() re-issues op_fetch.
                        return Ok(());
                    }
                    Ok(FetchOne::End) => {
                        stmt_handle.cursor_eof = true;
                        return Ok(());
                    }
                    Err(FetchErr::NeedMore) => {
                        // Rewind to the start of the incomplete response so it
                        // is parsed again from scratch once more bytes arrive
                        self.lazy_count = saved_lazy; // undo partial lazy consumption
                        view = snapshot;
                        break;
                    }
                    Err(FetchErr::Fatal(e)) => return Err(e),
                }
            }

            // Missing bytes: keep the unconsumed tail and read more from the socket.
            let mut next = BytesMut::from(view.as_ref());
            let n = self.socket.read(&mut self.buff)?;
            // A zero-length read means the peer closed the connection
            if n == 0 {
                return Err("Fetch: connection closed mid-batch".into());
            }
            next.extend_from_slice(&self.buff[..n]);
            acc = next;
        }
    }
885
886    /// Tries to parse ONE op_fetch_response from `view`. On NeedMore, the caller
887    /// restores `view` from a snapshot (the partial consumption here is discarded).
888    fn parse_one_fetch_response(
889        &mut self,
890        view: &mut Bytes,
891        xsqlda: &[XSqlVar],
892        tr_handle: &mut TrHandle,
893    ) -> Result<FetchOne, FetchErr> {
894        // op_code, skipping Dummy packets
895        let mut op_code = loop {
896            if view.remaining() < 4 {
897                return Err(FetchErr::NeedMore);
898            }
899            let oc = view.get_u32().map_err(|_| FetchErr::NeedMore)?;
900            if oc != WireOp::Dummy as u32 {
901                break oc;
902            }
903        };
904
905        // Pending lazy responses
906        for _ in 0..self.lazy_count {
907            if op_code != WireOp::Response as u32 {
908                return Err(FetchErr::Fatal(
909                    format!("unexpected op_code in fetch (op {})", op_code).into(),
910                ));
911            }
912            self.lazy_count -= 1;
913            parse_response(view).map_err(|_| FetchErr::NeedMore)?;
914            if view.remaining() < 4 {
915                return Err(FetchErr::NeedMore);
916            }
917            op_code = view.get_u32().map_err(|_| FetchErr::NeedMore)?;
918        }
919
920        if op_code == WireOp::Response as u32 {
921            // Error reported by the server
922            parse_response(view).map_err(FetchErr::Fatal)?;
923        }
924
925        if op_code != WireOp::FetchResponse as u32 {
926            return Err(FetchErr::Fatal(
927                format!("unexpected op_code in fetch (op {})", op_code).into(),
928            ));
929        }
930
931        // Body: [status: u32][messages: u32][null_map][columns...]. Peek status
932        // and messages without consuming, to tell end-of-cursor (status=100),
933        // end-of-batch (messages=0) and a row (messages=1) apart BEFORE delegating.
934        if view.remaining() < 4 {
935            return Err(FetchErr::NeedMore);
936        }
937        let status = {
938            let mut peek = view.clone();
939            peek.get_u32().map_err(|_| FetchErr::NeedMore)?
940        };
941        if status == 100 {
942            // End of cursor: consume only the status (same as parse_fetch_response).
943            view.advance(4).map_err(|_| FetchErr::NeedMore)?;
944            return Ok(FetchOne::End);
945        }
946        if view.remaining() < 8 {
947            return Err(FetchErr::NeedMore);
948        }
949        let messages = {
950            let mut peek = view.clone();
951            peek.get_u32().map_err(|_| FetchErr::NeedMore)?; // status
952            peek.get_u32().map_err(|_| FetchErr::NeedMore)? // messages
953        };
954        if messages == 0 {
955            // End of this batch with no row: consume status+messages, stop the batch.
956            view.advance(8).map_err(|_| FetchErr::NeedMore)?;
957            return Ok(FetchOne::BatchEnd);
958        }
959
960        // A row is present. Delegate to the crate parser (re-reads status+messages+data).
961        // charset cloned so we don't hold a borrow of self when calling into_column.
962        let version = self.version;
963        let charset = self.charset.clone();
964        match parse_fetch_response(view, xsqlda, version, &charset) {
965            Ok(None) => Ok(FetchOne::End),
966            Ok(Some(parsed)) => {
967                let mut cols = Vec::with_capacity(parsed.len());
968                for pc in parsed {
969                    cols.push(pc.into_column(self, tr_handle).map_err(FetchErr::Fatal)?);
970                }
971                Ok(FetchOne::Row(cols))
972            }
973            // Underflow (bytes missing mid-response) has a fixed message -> read
974            // more and resume. Any OTHER error (e.g. invalid UTF-8 when decoding a
975            // column) is a real data error: propagate it, don't turn it into an
976            // infinite wait for bytes that never arrive.
977            Err(e) => {
978                if matches!(&e, FbError::Other(m) if m == "Invalid server response, missing bytes")
979                {
980                    Err(FetchErr::NeedMore)
981                } else {
982                    Err(FetchErr::Fatal(e))
983                }
984            }
985        }
986    }
987
988    /// Create a new blob, returning the blob handle and id
989    pub fn create_blob(
990        &mut self,
991        tr_handle: &mut TrHandle,
992    ) -> Result<(BlobHandle, BlobId), FbError> {
993        self.socket.write_all(&create_blob(tr_handle.0))?;
994        self.socket.flush()?;
995
996        let resp = self.read_response()?;
997
998        Ok((BlobHandle(resp.handle), BlobId(resp.object_id)))
999    }
1000
1001    /// Put blob segments
1002    pub fn put_segments(&mut self, blob_handle: BlobHandle, data: &[u8]) -> Result<(), FbError> {
1003        for segment in data.chunks(crate::blr::MAX_DATA_LENGTH) {
1004            self.socket
1005                .write_all(&put_segment(blob_handle.0, segment))?;
1006            self.socket.flush()?;
1007
1008            self.read_response()?;
1009        }
1010
1011        Ok(())
1012    }
1013
1014    /// Open a blob, returning the blob handle
1015    pub fn open_blob(
1016        &mut self,
1017        tr_handle: &mut TrHandle,
1018        blob_id: BlobId,
1019    ) -> Result<BlobHandle, FbError> {
1020        self.socket.write_all(&open_blob(tr_handle.0, blob_id.0))?;
1021        self.socket.flush()?;
1022
1023        let resp = self.read_response()?;
1024
1025        Ok(BlobHandle(resp.handle))
1026    }
1027
1028    /// Get a blob segment, returns the bytes and true if there is more data
1029    pub fn get_segment(&mut self, blob_handle: BlobHandle) -> Result<(Bytes, bool), FbError> {
1030        self.socket.write_all(&get_segment(blob_handle.0))?;
1031        self.socket.flush()?;
1032
1033        let mut blob_data = BytesMut::with_capacity(256);
1034
1035        let resp = self.read_response()?;
1036        let mut data = resp.data;
1037
1038        loop {
1039            if data.remaining() < 2 {
1040                break;
1041            }
1042            let len = data.get_u16_le()? as usize;
1043            if data.remaining() < len {
1044                return err_invalid_response();
1045            }
1046            blob_data.put_slice(&data[..len]);
1047            data.advance(len)?;
1048        }
1049
1050        Ok((blob_data.freeze(), resp.handle == 2))
1051    }
1052
1053    /// Closes a blob handle
1054    pub fn close_blob(&mut self, blob_handle: BlobHandle) -> Result<(), FbError> {
1055        self.socket.write_all(&close_blob(blob_handle.0))?;
1056        self.socket.flush()?;
1057
1058        self.read_response()?;
1059
1060        Ok(())
1061    }
1062
1063    /// Read a server response
1064    fn read_response(&mut self) -> Result<Response, FbError> {
1065        read_response(&mut self.socket, &mut self.buff, &mut self.lazy_count)
1066    }
1067
1068    /// Reads a packet from the socket
1069    fn read_packet(&mut self) -> Result<(u32, Bytes), FbError> {
1070        read_packet(&mut self.socket, &mut self.buff)
1071    }
1072}
1073
1074/// Read a server response
1075fn read_response(
1076    socket: &mut impl Read,
1077    buff: &mut [u8],
1078    lazy_count: &mut u32,
1079) -> Result<Response, FbError> {
1080    let (mut op_code, mut resp) = read_packet(socket, buff)?;
1081
1082    // Read lazy responses
1083    for _ in 0..*lazy_count {
1084        if op_code != WireOp::Response as u32 {
1085            return err_conn_rejected(op_code);
1086        }
1087        *lazy_count -= 1;
1088        parse_response(&mut resp)?;
1089
1090        op_code = resp.get_u32()?;
1091    }
1092
1093    if op_code != WireOp::Response as u32 {
1094        return err_conn_rejected(op_code);
1095    }
1096
1097    parse_response(&mut resp)
1098}
1099
1100/// Reads a packet from the socket
1101fn read_packet(socket: &mut impl Read, buff: &mut [u8]) -> Result<(u32, Bytes), FbError> {
1102    let mut len = socket.read(buff)?;
1103    let mut resp = BytesMut::from(&buff[..len]);
1104
1105    loop {
1106        if len == buff.len() {
1107            // The buffer was not large enough, so read more
1108            len = socket.read(buff)?;
1109            resp.put_slice(&buff[..len]);
1110        } else {
1111            break;
1112        }
1113    }
1114    let mut resp = resp.freeze();
1115
1116    let op_code = loop {
1117        let op_code = resp.get_u32()?;
1118
1119        if op_code != WireOp::Dummy as u32 {
1120            break op_code;
1121        }
1122    };
1123
1124    Ok((op_code, resp))
1125}
1126
1127/// Performs the srp authentication with the server, returning the encrypted stream
1128fn srp_auth<D>(
1129    mut socket: FbStream,
1130    buff: &mut [u8],
1131    srp: SrpClient<D>,
1132    plugin: AuthPluginType,
1133    user: &str,
1134    pass: &str,
1135    data: SrpAuthData,
1136) -> Result<FbStream, FbError>
1137where
1138    D: digest::Digest,
1139{
1140    // Generate a private key with the salt received from the server
1141    let private_key = srp_private_key::<sha1::Sha1>(user.as_bytes(), pass.as_bytes(), &data.salt);
1142
1143    // Generate a verified with the private key above and the server public key received
1144    let verifier = srp
1145        .process_reply(user.as_bytes(), &data.salt, &private_key, &data.pub_key)
1146        .map_err(|e| FbError::from(format!("Srp error: {}", e)))?;
1147
1148    // Generate a proof to send to the server so it can verify the password
1149    let proof = hex::encode(verifier.get_proof());
1150
1151    // Send proof data
1152    socket.write_all(&cont_auth(
1153        proof.as_bytes(),
1154        plugin,
1155        AuthPluginType::plugin_list(),
1156        &[],
1157    ))?;
1158    socket.flush()?;
1159
1160    read_response(&mut socket, buff, &mut 0)?;
1161
1162    // Enable wire encryption
1163    socket.write_all(&crypt("Arc4", "Symmetric"))?;
1164    socket.flush()?;
1165
1166    socket = FbStream::Arc4(Arc4Stream::new(
1167        match socket {
1168            FbStream::Plain(s) => s,
1169            _ => unreachable!("Stream was already encrypted!"),
1170        },
1171        &verifier.get_key(),
1172        buff.len(),
1173    ));
1174
1175    read_response(&mut socket, buff, &mut 0)?;
1176
1177    Ok(socket)
1178}
1179
/// A database handle, wrapping the raw `u32` value.
#[derive(Debug, Clone, Copy)]
pub struct DbHandle(u32);
1183
/// A transaction handle, wrapping the raw `u32` value.
#[derive(Debug, Clone, Copy)]
pub struct TrHandle(u32);
1187
/// A statement handle, wrapping the raw `u32` value. Comparable and hashable,
/// so it can be used as a map or set key.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct StmtHandle(u32);
1191
/// A blob handle, wrapping the raw `u32` value.
#[derive(Debug, Clone, Copy)]
pub struct BlobHandle(u32);
1195
/// A blob identifier, wrapping the raw `u64` value (readable crate-wide).
#[derive(Debug, Clone, Copy)]
pub struct BlobId(pub(crate) u64);
1199
1200/// Firebird tcp stream, may be encrypted
1201enum FbStream {
1202    /// Plaintext stream
1203    Plain(TcpStream),
1204
1205    /// Arc4 ecrypted stream
1206    Arc4(Arc4Stream<TcpStream>),
1207}
1208
1209impl Read for FbStream {
1210    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
1211        match self {
1212            FbStream::Plain(s) => s.read(buf),
1213            FbStream::Arc4(s) => s.read(buf),
1214        }
1215    }
1216}
1217
1218impl Write for FbStream {
1219    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1220        match self {
1221            FbStream::Plain(s) => s.write(buf),
1222            FbStream::Arc4(s) => s.write(buf),
1223        }
1224    }
1225
1226    fn flush(&mut self) -> std::io::Result<()> {
1227        match self {
1228            FbStream::Plain(s) => s.flush(),
1229            FbStream::Arc4(s) => s.flush(),
1230        }
1231    }
1232}
1233
/// Manual smoke test (ignored by default): connects to 127.0.0.1:3050,
/// attaches `test.fdb`, runs one parameterized SELECT and prints every
/// fetched row until the cursor reports no more rows.
#[test]
#[ignore]
fn connection_test() {
    use rsfbclient_core::charset::UTF_8;

    let (db_name, user, pass) = ("test.fdb", "SYSDBA", "masterkey");

    let mut conn =
        FirebirdWireConnection::connect("127.0.0.1", 3050, db_name, user, pass, UTF_8).unwrap();

    let mut db_handle = conn
        .attach_database(db_name, user, pass, None, Dialect::D3, false)
        .unwrap();

    let mut tr_handle = conn
        .begin_transaction(&mut db_handle, TransactionConfiguration::default())
        .unwrap();

    let prepared = conn
        .prepare_statement(
            &mut db_handle,
            &mut tr_handle,
            Dialect::D3,
            "
            SELECT
                1, 'abcdefghij' as tst, rand(), CURRENT_DATE, CURRENT_TIME, CURRENT_TIMESTAMP, -1, -2, -3, -4, -5, 1, 2, 3, 4, 5, 0 as last
            FROM RDB$DATABASE where 1 = ?
            ",
            // Alternative query exercising multiple rows and NULL columns:
            // "
            // SELECT cast(1 as bigint), cast('abcdefghij' as varchar(10)) as tst FROM RDB$DATABASE UNION ALL
            // SELECT cast(2 as bigint), cast('abcdefgh' as varchar(10)) as tst FROM RDB$DATABASE UNION ALL
            // SELECT cast(3 as bigint), cast('abcdef' as varchar(10)) as tst FROM RDB$DATABASE UNION ALL
            // SELECT cast(4 as bigint), cast(null as varchar(10)) as tst FROM RDB$DATABASE UNION ALL
            // SELECT cast(null as bigint), cast('abcd' as varchar(10)) as tst FROM RDB$DATABASE
            // ",
        )
        .unwrap();
    let (stmt_type, mut stmt_handle) = prepared;

    println!("Statement type: {:?}", stmt_type);

    // The single positional parameter bound to `where 1 = ?`
    let params = match rsfbclient_core::IntoParams::to_params((1,)) {
        rsfbclient_core::ParamsType::Positional(params) => params,
        _ => unreachable!(),
    };

    conn.execute(&mut tr_handle, &mut stmt_handle, &params)
        .unwrap();

    // Print rows until fetch reports the end of the cursor
    loop {
        let resp = conn.fetch(&mut tr_handle, &mut stmt_handle).unwrap();

        match resp {
            None => break,
            Some(_) => println!("Fetch Resp: {:#?}", resp),
        }
    }

    std::thread::sleep(std::time::Duration::from_millis(100));
}