use bytes::{BufMut, Bytes, BytesMut};
use std::{
env,
io::{Read, Write},
net::TcpStream,
};
use crate::{
arc4::*,
blr,
consts::{AuthPluginType, ProtocolVersion, WireOp},
srp::*,
util::*,
wire::*,
xsqlda::{parse_xsqlda, xsqlda_to_blr, PrepareInfo, XSqlVar, XSQLDA_DESCRIBE_VARS},
};
use rsfbclient_core::{
ibase, Charset, Column, Dialect, FbError, FirebirdClientDbOps, FirebirdClientSqlOps,
FreeStmtOp, SqlType, StmtType, TrIsolationLevel, TrOp,
};
type RustDbHandle = DbHandle;
type RustTrHandle = TrHandle;
type RustStmtHandle = StmtHandle;
pub struct RustFbClient {
conn: Option<FirebirdWireConnection>,
charset: Charset,
}
#[derive(Default, Clone)]
pub struct RustFbClientAttachmentConfig {
pub host: String,
pub port: u16,
pub db_name: String,
pub user: String,
pub pass: String,
}
pub struct FirebirdWireConnection {
socket: FbStream,
pub(crate) version: ProtocolVersion,
buff: Box<[u8]>,
lazy_count: u32,
pub(crate) charset: Charset,
}
pub struct StmtHandleData {
handle: RustStmtHandle,
xsqlda: Vec<XSqlVar>,
blr: Bytes,
param_count: usize,
}
impl RustFbClient {
pub fn new(charset: Charset) -> Self {
Self {
conn: None,
charset,
}
}
}
impl FirebirdClientDbOps for RustFbClient {
type DbHandle = RustDbHandle;
type AttachmentConfig = RustFbClientAttachmentConfig;
fn attach_database(
&mut self,
config: &Self::AttachmentConfig,
) -> Result<RustDbHandle, FbError> {
let host = config.host.as_str();
let port = config.port;
let db_name = config.db_name.as_str();
let user = config.user.as_str();
let pass = config.pass.as_str();
let mut conn = match self.conn.take() {
Some(conn) => conn,
None => FirebirdWireConnection::connect(
host,
port,
db_name,
user,
pass,
self.charset.clone(),
)?,
};
let attach_result = conn.attach_database(db_name, user, pass);
self.conn.replace(conn);
attach_result
}
fn detach_database(&mut self, db_handle: &mut RustDbHandle) -> Result<(), FbError> {
self.conn
.as_mut()
.map(|conn| conn.detach_database(db_handle))
.unwrap_or_else(err_client_not_connected)
}
fn drop_database(&mut self, db_handle: &mut RustDbHandle) -> Result<(), FbError> {
self.conn
.as_mut()
.map(|conn| conn.drop_database(db_handle))
.unwrap_or_else(err_client_not_connected)
}
}
impl FirebirdClientSqlOps for RustFbClient {
type DbHandle = RustDbHandle;
type TrHandle = RustTrHandle;
type StmtHandle = StmtHandleData;
fn begin_transaction(
&mut self,
db_handle: &mut Self::DbHandle,
isolation_level: TrIsolationLevel,
) -> Result<Self::TrHandle, FbError> {
self.conn
.as_mut()
.map(|conn| conn.begin_transaction(db_handle, isolation_level))
.unwrap_or_else(err_client_not_connected)
}
fn transaction_operation(
&mut self,
tr_handle: &mut Self::TrHandle,
op: TrOp,
) -> Result<(), FbError> {
self.conn
.as_mut()
.map(|conn| conn.transaction_operation(tr_handle, op))
.unwrap_or_else(err_client_not_connected)
}
fn exec_immediate(
&mut self,
_db_handle: &mut Self::DbHandle,
tr_handle: &mut Self::TrHandle,
dialect: Dialect,
sql: &str,
) -> Result<(), FbError> {
self.conn
.as_mut()
.map(|conn| conn.exec_immediate(tr_handle, dialect, sql))
.unwrap_or_else(err_client_not_connected)
}
fn prepare_statement(
&mut self,
db_handle: &mut Self::DbHandle,
tr_handle: &mut Self::TrHandle,
dialect: Dialect,
sql: &str,
) -> Result<(StmtType, Self::StmtHandle), FbError> {
self.conn
.as_mut()
.map(|conn| conn.prepare_statement(db_handle, tr_handle, dialect, sql))
.unwrap_or_else(err_client_not_connected)
}
fn free_statement(
&mut self,
stmt_handle: &mut Self::StmtHandle,
op: FreeStmtOp,
) -> Result<(), FbError> {
self.conn
.as_mut()
.map(|conn| conn.free_statement(stmt_handle, op))
.unwrap_or_else(err_client_not_connected)
}
fn execute(
&mut self,
_db_handle: &mut Self::DbHandle,
tr_handle: &mut Self::TrHandle,
stmt_handle: &mut Self::StmtHandle,
params: Vec<SqlType>,
) -> Result<usize, FbError> {
self.conn
.as_mut()
.map(|conn| conn.execute(tr_handle, stmt_handle, ¶ms))
.unwrap_or_else(err_client_not_connected)
}
fn execute2(
&mut self,
_db_handle: &mut Self::DbHandle,
tr_handle: &mut Self::TrHandle,
stmt_handle: &mut Self::StmtHandle,
params: Vec<SqlType>,
) -> Result<Vec<Column>, FbError> {
self.conn
.as_mut()
.map(|conn| conn.execute2(tr_handle, stmt_handle, ¶ms))
.unwrap_or_else(err_client_not_connected)
}
fn fetch(
&mut self,
_db_handle: &mut Self::DbHandle,
tr_handle: &mut Self::TrHandle,
stmt_handle: &mut Self::StmtHandle,
) -> Result<Option<Vec<Column>>, FbError> {
self.conn
.as_mut()
.map(|conn| conn.fetch(tr_handle, stmt_handle))
.unwrap_or_else(err_client_not_connected)
}
}
fn err_client_not_connected<T>() -> Result<T, FbError> {
Err("Client not connected to the server, call `attach_database` to connect".into())
}
impl FirebirdWireConnection {
pub fn connect(
host: &str,
port: u16,
db_name: &str,
user: &str,
pass: &str,
charset: Charset,
) -> Result<Self, FbError> {
let socket = TcpStream::connect((host, port))?;
let username =
env::var("USER").unwrap_or_else(|_| env::var("USERNAME").unwrap_or_default());
let hostname = socket
.local_addr()
.map(|addr| addr.to_string())
.unwrap_or_default();
let mut socket = FbStream::Plain(socket);
let srp_key: [u8; 32] = rand::random();
let req = connect(db_name, false, user, &username, &hostname, &srp_key);
socket.write_all(&req)?;
socket.flush()?;
let mut buff = vec![0; BUFFER_LENGTH as usize * 2].into_boxed_slice();
let len = socket.read(&mut buff)?;
let mut resp = Bytes::copy_from_slice(&buff[..len]);
let ConnectionResponse {
version,
auth_plugin,
} = parse_accept(&mut resp)?;
if let Some(mut auth_plugin) = auth_plugin {
loop {
match auth_plugin.kind {
plugin @ AuthPluginType::Srp => {
let srp = SrpClient::<sha1::Sha1>::new(&srp_key, &SRP_GROUP);
if let Some(data) = auth_plugin.data {
socket = srp_auth(socket, &mut buff, srp, plugin, user, pass, data)?;
break;
} else {
socket.write_all(&cont_auth(
hex::encode(srp.get_a_pub()).as_bytes(),
plugin,
AuthPluginType::plugin_list(),
&[],
))?;
socket.flush()?;
let len = socket.read(&mut buff)?;
let mut resp = Bytes::copy_from_slice(&buff[..len]);
auth_plugin = parse_cont_auth(&mut resp)?;
}
}
plugin @ AuthPluginType::Srp256 => {
let srp = SrpClient::<sha2::Sha256>::new(&srp_key, &SRP_GROUP);
if let Some(data) = auth_plugin.data {
socket = srp_auth(socket, &mut buff, srp, plugin, user, pass, data)?;
break;
} else {
socket.write_all(&cont_auth(
&hex::encode(srp.get_a_pub()).as_bytes(),
plugin,
AuthPluginType::plugin_list(),
&[],
))?;
socket.flush()?;
let len = socket.read(&mut buff)?;
let mut resp = Bytes::copy_from_slice(&buff[..len]);
auth_plugin = parse_cont_auth(&mut resp)?;
}
}
}
}
}
Ok(Self {
socket,
version,
buff,
lazy_count: 0,
charset,
})
}
pub fn attach_database(
&mut self,
db_name: &str,
user: &str,
pass: &str,
) -> Result<DbHandle, FbError> {
self.socket.write_all(&attach(
db_name,
user,
pass,
self.version,
self.charset.clone(),
))?;
self.socket.flush()?;
let resp = self.read_response()?;
Ok(DbHandle(resp.handle))
}
pub fn detach_database(&mut self, db_handle: &mut DbHandle) -> Result<(), FbError> {
self.socket.write_all(&detach(db_handle.0))?;
self.socket.flush()?;
self.read_response()?;
Ok(())
}
pub fn drop_database(&mut self, db_handle: &mut DbHandle) -> Result<(), FbError> {
self.socket.write_all(&drop_database(db_handle.0))?;
self.socket.flush()?;
self.read_response()?;
Ok(())
}
pub fn begin_transaction(
&mut self,
db_handle: &mut DbHandle,
isolation_level: TrIsolationLevel,
) -> Result<TrHandle, FbError> {
let tpb = [ibase::isc_tpb_version3 as u8, isolation_level as u8];
self.socket.write_all(&transaction(db_handle.0, &tpb))?;
self.socket.flush()?;
let resp = self.read_response()?;
Ok(TrHandle(resp.handle))
}
pub fn transaction_operation(
&mut self,
tr_handle: &mut TrHandle,
op: TrOp,
) -> Result<(), FbError> {
self.socket
.write_all(&transaction_operation(tr_handle.0, op))?;
self.socket.flush()?;
self.read_response()?;
Ok(())
}
pub fn exec_immediate(
&mut self,
tr_handle: &mut TrHandle,
dialect: Dialect,
sql: &str,
) -> Result<(), FbError> {
self.socket.write_all(&exec_immediate(
tr_handle.0,
dialect as u32,
sql,
&self.charset,
)?)?;
self.socket.flush()?;
self.read_response()?;
Ok(())
}
pub fn prepare_statement(
&mut self,
db_handle: &mut DbHandle,
tr_handle: &mut TrHandle,
dialect: Dialect,
sql: &str,
) -> Result<(StmtType, StmtHandleData), FbError> {
self.socket.write_all(&allocate_statement(db_handle.0))?;
self.socket.write_all(&prepare_statement(
tr_handle.0,
u32::MAX,
dialect as u32,
sql,
&self.charset,
)?)?;
self.socket.flush()?;
let (mut op_code, mut resp) = self.read_packet()?;
for _ in 0..self.lazy_count {
if op_code != WireOp::Response as u32 {
return err_conn_rejected(op_code);
}
self.lazy_count -= 1;
parse_response(&mut resp)?;
op_code = resp.get_u32()?;
}
if op_code != WireOp::Response as u32 {
return err_conn_rejected(op_code);
}
let stmt_handle = StmtHandle(parse_response(&mut resp)?.handle);
let op_code = resp.get_u32()?;
if op_code != WireOp::Response as u32 {
return err_conn_rejected(op_code);
}
let mut xsqlda = Vec::new();
let mut resp = parse_response(&mut resp)?;
let PrepareInfo {
stmt_type,
mut param_count,
mut truncated,
} = parse_xsqlda(&mut resp.data, &mut xsqlda)?;
while truncated {
let next_index = (xsqlda.len() as u16).to_le_bytes();
self.socket.write_all(&info_sql(
stmt_handle.0,
&[
&[
ibase::isc_info_sql_sqlda_start as u8,
2,
next_index[0],
next_index[1],
],
&XSQLDA_DESCRIBE_VARS[..],
]
.concat(),
))?;
self.socket.flush()?;
let mut data = self.read_response()?.data;
let parse_resp = parse_xsqlda(&mut data, &mut xsqlda)?;
truncated = parse_resp.truncated;
param_count = parse_resp.param_count;
}
for var in xsqlda.iter_mut() {
var.coerce()?;
}
let blr = xsqlda_to_blr(&xsqlda)?;
Ok((
stmt_type,
StmtHandleData {
handle: stmt_handle,
xsqlda,
blr,
param_count,
},
))
}
pub fn free_statement(
&mut self,
stmt_handle: &mut StmtHandleData,
op: FreeStmtOp,
) -> Result<(), FbError> {
self.socket
.write_all(&free_statement(stmt_handle.handle.0, op))?;
self.lazy_count += 1;
Ok(())
}
pub fn execute(
&mut self,
tr_handle: &mut TrHandle,
stmt_handle: &mut StmtHandleData,
params: &[SqlType],
) -> Result<usize, FbError> {
if params.len() != stmt_handle.param_count {
return Err(format!(
"Tried to execute a statement that has {} parameters while providing {}",
stmt_handle.param_count,
params.len()
)
.into());
}
let params = blr::params_to_blr(self, tr_handle, params)?;
self.socket.write_all(&execute(
tr_handle.0,
stmt_handle.handle.0,
¶ms.blr,
¶ms.values,
))?;
self.socket.flush()?;
self.read_response()?;
self.socket.write_all(&info_sql(
stmt_handle.handle.0,
&[ibase::isc_info_sql_records as u8],
))?;
self.socket.flush()?;
let mut data = self.read_response()?.data;
Ok(parse_info_sql_affected_rows(&mut data)?)
}
pub fn execute2(
&mut self,
tr_handle: &mut TrHandle,
stmt_handle: &mut StmtHandleData,
params: &[SqlType],
) -> Result<Vec<Column>, FbError> {
if params.len() != stmt_handle.param_count {
return Err(format!(
"Tried to execute a statement that has {} parameters while providing {}",
stmt_handle.param_count,
params.len()
)
.into());
}
let params = blr::params_to_blr(self, tr_handle, params)?;
self.socket.write_all(&execute2(
tr_handle.0,
stmt_handle.handle.0,
¶ms.blr,
¶ms.values,
&stmt_handle.blr,
))?;
self.socket.flush()?;
let (mut op_code, mut resp) = read_packet(&mut self.socket, &mut self.buff)?;
for _ in 0..self.lazy_count {
if op_code != WireOp::Response as u32 {
return err_conn_rejected(op_code);
}
self.lazy_count -= 1;
parse_response(&mut resp)?;
op_code = resp.get_u32()?;
}
if op_code == WireOp::Response as u32 {
parse_response(&mut resp)?;
}
if op_code != WireOp::SqlResponse as u32 {
return err_conn_rejected(op_code);
}
let parsed_cols =
parse_sql_response(&mut resp, &stmt_handle.xsqlda, self.version, &self.charset)?;
parse_response(&mut resp)?;
let mut cols = Vec::with_capacity(parsed_cols.len());
for pc in parsed_cols {
cols.push(pc.into_column(self, tr_handle)?);
}
Ok(cols)
}
pub fn fetch(
&mut self,
tr_handle: &mut TrHandle,
stmt_handle: &mut StmtHandleData,
) -> Result<Option<Vec<Column>>, FbError> {
self.socket
.write_all(&fetch(stmt_handle.handle.0, &stmt_handle.blr))?;
self.socket.flush()?;
let (mut op_code, mut resp) = read_packet(&mut self.socket, &mut self.buff)?;
for _ in 0..self.lazy_count {
if op_code != WireOp::Response as u32 {
return err_conn_rejected(op_code);
}
self.lazy_count -= 1;
parse_response(&mut resp)?;
op_code = resp.get_u32()?;
}
if op_code == WireOp::Response as u32 {
parse_response(&mut resp)?;
}
if op_code != WireOp::FetchResponse as u32 {
return err_conn_rejected(op_code);
}
if let Some(parsed_cols) =
parse_fetch_response(&mut resp, &stmt_handle.xsqlda, self.version, &self.charset)?
{
let mut cols = Vec::with_capacity(parsed_cols.len());
for pc in parsed_cols {
cols.push(pc.into_column(self, tr_handle)?);
}
Ok(Some(cols))
} else {
Ok(None)
}
}
pub fn create_blob(
&mut self,
tr_handle: &mut TrHandle,
) -> Result<(BlobHandle, BlobId), FbError> {
self.socket.write_all(&create_blob(tr_handle.0))?;
self.socket.flush()?;
let resp = self.read_response()?;
Ok((BlobHandle(resp.handle), BlobId(resp.object_id)))
}
pub fn put_segments(&mut self, blob_handle: BlobHandle, data: &[u8]) -> Result<(), FbError> {
for segment in data.chunks(crate::blr::MAX_DATA_LENGTH) {
self.socket
.write_all(&put_segment(blob_handle.0, segment))?;
self.socket.flush()?;
self.read_response()?;
}
Ok(())
}
pub fn open_blob(
&mut self,
tr_handle: &mut TrHandle,
blob_id: BlobId,
) -> Result<BlobHandle, FbError> {
self.socket.write_all(&open_blob(tr_handle.0, blob_id.0))?;
self.socket.flush()?;
let resp = self.read_response()?;
Ok(BlobHandle(resp.handle))
}
pub fn get_segment(&mut self, blob_handle: BlobHandle) -> Result<(Bytes, bool), FbError> {
self.socket.write_all(&get_segment(blob_handle.0))?;
self.socket.flush()?;
let mut blob_data = BytesMut::with_capacity(256);
let resp = self.read_response()?;
let mut data = resp.data;
loop {
if data.remaining() < 2 {
break;
}
let len = data.get_u16_le()? as usize;
if data.remaining() < len {
return err_invalid_response();
}
blob_data.put_slice(&data[..len]);
data.advance(len)?;
}
Ok((blob_data.freeze(), resp.handle == 2))
}
pub fn close_blob(&mut self, blob_handle: BlobHandle) -> Result<(), FbError> {
self.socket.write_all(&close_blob(blob_handle.0))?;
self.socket.flush()?;
self.read_response()?;
Ok(())
}
fn read_response(&mut self) -> Result<Response, FbError> {
read_response(&mut self.socket, &mut self.buff, &mut self.lazy_count)
}
fn read_packet(&mut self) -> Result<(u32, Bytes), FbError> {
read_packet(&mut self.socket, &mut self.buff)
}
}
fn read_response(
socket: &mut impl Read,
buff: &mut [u8],
lazy_count: &mut u32,
) -> Result<Response, FbError> {
let (mut op_code, mut resp) = read_packet(socket, buff)?;
for _ in 0..*lazy_count {
if op_code != WireOp::Response as u32 {
return err_conn_rejected(op_code);
}
*lazy_count -= 1;
parse_response(&mut resp)?;
op_code = resp.get_u32()?;
}
if op_code != WireOp::Response as u32 {
return err_conn_rejected(op_code);
}
parse_response(&mut resp)
}
fn read_packet(socket: &mut impl Read, buff: &mut [u8]) -> Result<(u32, Bytes), FbError> {
let mut len = socket.read(buff)?;
let mut resp = BytesMut::from(&buff[..len]);
loop {
if len == buff.len() {
len = socket.read(buff)?;
resp.put_slice(&buff[..len]);
} else {
break;
}
}
let mut resp = resp.freeze();
let op_code = loop {
let op_code = resp.get_u32()?;
if op_code != WireOp::Dummy as u32 {
break op_code;
}
};
Ok((op_code, resp))
}
fn srp_auth<D>(
mut socket: FbStream,
buff: &mut [u8],
srp: SrpClient<D>,
plugin: AuthPluginType,
user: &str,
pass: &str,
data: SrpAuthData,
) -> Result<FbStream, FbError>
where
D: digest::Digest,
{
let private_key = srp_private_key::<sha1::Sha1>(user.as_bytes(), pass.as_bytes(), &data.salt);
let verifier = srp
.process_reply(user.as_bytes(), &data.salt, &private_key, &data.pub_key)
.map_err(|e| FbError::from(format!("Srp error: {}", e)))?;
let proof = hex::encode(verifier.get_proof());
socket.write_all(&cont_auth(
&proof.as_bytes(),
plugin,
AuthPluginType::plugin_list(),
&[],
))?;
socket.flush()?;
read_response(&mut socket, buff, &mut 0)?;
socket.write_all(&crypt("Arc4", "Symmetric"))?;
socket.flush()?;
socket = FbStream::Arc4(Arc4Stream::new(
match socket {
FbStream::Plain(s) => s,
_ => unreachable!("Stream was already encrypted!"),
},
&verifier.get_key(),
buff.len(),
));
read_response(&mut socket, buff, &mut 0)?;
Ok(socket)
}
#[derive(Debug, Clone, Copy)]
pub struct DbHandle(u32);
#[derive(Debug, Clone, Copy)]
pub struct TrHandle(u32);
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct StmtHandle(u32);
#[derive(Debug, Clone, Copy)]
pub struct BlobHandle(u32);
#[derive(Debug, Clone, Copy)]
pub struct BlobId(pub(crate) u64);
enum FbStream {
Plain(TcpStream),
Arc4(Arc4Stream<TcpStream>),
}
impl Read for FbStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
FbStream::Plain(s) => s.read(buf),
FbStream::Arc4(s) => s.read(buf),
}
}
}
impl Write for FbStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
FbStream::Plain(s) => s.write(buf),
FbStream::Arc4(s) => s.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
FbStream::Plain(s) => s.flush(),
FbStream::Arc4(s) => s.flush(),
}
}
}
#[test]
#[ignore]
fn connection_test() {
use rsfbclient_core::charset::UTF_8;
let db_name = "test.fdb";
let user = "SYSDBA";
let pass = "masterkey";
let mut conn =
FirebirdWireConnection::connect("127.0.0.1", 3050, db_name, user, pass, UTF_8).unwrap();
let mut db_handle = conn.attach_database(db_name, user, pass).unwrap();
let mut tr_handle = conn
.begin_transaction(&mut db_handle, TrIsolationLevel::Concurrency)
.unwrap();
let (stmt_type, mut stmt_handle) = conn
.prepare_statement(
&mut db_handle,
&mut tr_handle,
Dialect::D3,
"
SELECT
1, 'abcdefghij' as tst, rand(), CURRENT_DATE, CURRENT_TIME, CURRENT_TIMESTAMP, -1, -2, -3, -4, -5, 1, 2, 3, 4, 5, 0 as last
FROM RDB$DATABASE where 1 = ?
",
)
.unwrap();
println!("Statement type: {:?}", stmt_type);
let params = match rsfbclient_core::IntoParams::to_params((1,)) {
rsfbclient_core::ParamsType::Positional(params) => params,
_ => unreachable!(),
};
conn.execute(&mut tr_handle, &mut stmt_handle, ¶ms)
.unwrap();
loop {
let resp = conn.fetch(&mut tr_handle, &mut stmt_handle).unwrap();
if resp.is_none() {
break;
}
println!("Fetch Resp: {:#?}", resp);
}
std::thread::sleep(std::time::Duration::from_millis(100));
}