Skip to main content

kimberlite_client/
client.rs

1//! RPC client for `Kimberlite`.
2
3use std::io::{Read, Write};
4use std::net::{TcpStream, ToSocketAddrs};
5use std::time::Duration;
6
7use bytes::BytesMut;
8use kimberlite_types::{DataClass, Offset, Placement, StreamId, TenantId};
9use kimberlite_wire::{
10    AppendEventsRequest, CreateStreamRequest, ErrorCode, Frame, HandshakeRequest, PROTOCOL_VERSION,
11    QueryAtRequest, QueryParam, QueryRequest, QueryResponse, ReadEventsRequest, ReadEventsResponse,
12    Request, RequestId, RequestPayload, Response, ResponsePayload, SyncRequest,
13};
14
15use crate::error::{ClientError, ClientResult};
16
/// Settings that control how a [`Client`] talks to the server.
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Socket read timeout; `None` means reads block indefinitely.
    pub read_timeout: Option<Duration>,
    /// Socket write timeout; `None` means writes block indefinitely.
    pub write_timeout: Option<Duration>,
    /// Initial capacity, in bytes, of the client's read buffer.
    ///
    /// A response is rejected if more than twice this many undecoded bytes
    /// accumulate while waiting for it to complete.
    pub buffer_size: usize,
    /// Optional token sent to the server during the handshake.
    pub auth_token: Option<String>,
}

impl Default for ClientConfig {
    /// Returns 30-second read and write timeouts, a 64 KiB buffer, and no
    /// authentication token.
    fn default() -> Self {
        let thirty_seconds = Some(Duration::from_secs(30));
        Self {
            read_timeout: thirty_seconds,
            write_timeout: thirty_seconds,
            buffer_size: 65_536,
            auth_token: None,
        }
    }
}
40
41/// RPC client for `Kimberlite`.
42///
43/// This client uses synchronous I/O to communicate with a `Kimberlite` server
44/// using the binary wire protocol.
45///
46/// # Example
47///
48/// ```ignore
49/// use kimberlite_client::{Client, ClientConfig};
50/// use kimberlite_types::{DataClass, TenantId};
51///
52/// let mut client = Client::connect("127.0.0.1:5432", TenantId::new(1), ClientConfig::default())?;
53///
54/// // Create a stream
55/// let stream_id = client.create_stream("events", DataClass::NonPHI)?;
56///
57/// // Append events
58/// let offset = client.append(stream_id, vec![b"event1".to_vec(), b"event2".to_vec()])?;
59/// ```
60pub struct Client {
61    stream: TcpStream,
62    tenant_id: TenantId,
63    next_request_id: u64,
64    read_buf: BytesMut,
65    config: ClientConfig,
66}
67
68impl Client {
69    /// Connects to a `Kimberlite` server.
70    pub fn connect(
71        addr: impl ToSocketAddrs,
72        tenant_id: TenantId,
73        config: ClientConfig,
74    ) -> ClientResult<Self> {
75        let stream = TcpStream::connect(addr)?;
76        stream.set_read_timeout(config.read_timeout)?;
77        stream.set_write_timeout(config.write_timeout)?;
78
79        let mut client = Self {
80            stream,
81            tenant_id,
82            next_request_id: 1,
83            read_buf: BytesMut::with_capacity(config.buffer_size),
84            config,
85        };
86
87        // Perform handshake
88        client.handshake()?;
89
90        Ok(client)
91    }
92
93    /// Performs the handshake with the server.
94    fn handshake(&mut self) -> ClientResult<()> {
95        let response = self.send_request(RequestPayload::Handshake(HandshakeRequest {
96            client_version: PROTOCOL_VERSION,
97            auth_token: self.config.auth_token.clone(),
98        }))?;
99
100        match response.payload {
101            ResponsePayload::Handshake(h) => {
102                if h.server_version != PROTOCOL_VERSION {
103                    return Err(ClientError::HandshakeFailed(format!(
104                        "protocol version mismatch: client {}, server {}",
105                        PROTOCOL_VERSION, h.server_version
106                    )));
107                }
108                Ok(())
109            }
110            ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
111            _ => Err(ClientError::UnexpectedResponse {
112                expected: "Handshake".to_string(),
113                actual: format!("{:?}", response.payload),
114            }),
115        }
116    }
117
118    /// Creates a new stream.
119    pub fn create_stream(&mut self, name: &str, data_class: DataClass) -> ClientResult<StreamId> {
120        self.create_stream_with_placement(name, data_class, Placement::Global)
121    }
122
123    /// Creates a new stream with a specific placement policy.
124    pub fn create_stream_with_placement(
125        &mut self,
126        name: &str,
127        data_class: DataClass,
128        placement: Placement,
129    ) -> ClientResult<StreamId> {
130        let response = self.send_request(RequestPayload::CreateStream(CreateStreamRequest {
131            name: name.to_string(),
132            data_class,
133            placement,
134        }))?;
135
136        match response.payload {
137            ResponsePayload::CreateStream(r) => Ok(r.stream_id),
138            ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
139            _ => Err(ClientError::UnexpectedResponse {
140                expected: "CreateStream".to_string(),
141                actual: format!("{:?}", response.payload),
142            }),
143        }
144    }
145
146    /// Appends events to a stream.
147    ///
148    /// Returns the offset of the first appended event.
149    pub fn append(&mut self, stream_id: StreamId, events: Vec<Vec<u8>>) -> ClientResult<Offset> {
150        let response = self.send_request(RequestPayload::AppendEvents(AppendEventsRequest {
151            stream_id,
152            events,
153        }))?;
154
155        match response.payload {
156            ResponsePayload::AppendEvents(r) => Ok(r.first_offset),
157            ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
158            _ => Err(ClientError::UnexpectedResponse {
159                expected: "AppendEvents".to_string(),
160                actual: format!("{:?}", response.payload),
161            }),
162        }
163    }
164
165    /// Executes a SQL query.
166    pub fn query(&mut self, sql: &str, params: &[QueryParam]) -> ClientResult<QueryResponse> {
167        let response = self.send_request(RequestPayload::Query(QueryRequest {
168            sql: sql.to_string(),
169            params: params.to_vec(),
170        }))?;
171
172        match response.payload {
173            ResponsePayload::Query(r) => Ok(r),
174            ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
175            _ => Err(ClientError::UnexpectedResponse {
176                expected: "Query".to_string(),
177                actual: format!("{:?}", response.payload),
178            }),
179        }
180    }
181
182    /// Executes a SQL query at a specific position.
183    pub fn query_at(
184        &mut self,
185        sql: &str,
186        params: &[QueryParam],
187        position: Offset,
188    ) -> ClientResult<QueryResponse> {
189        let response = self.send_request(RequestPayload::QueryAt(QueryAtRequest {
190            sql: sql.to_string(),
191            params: params.to_vec(),
192            position,
193        }))?;
194
195        match response.payload {
196            ResponsePayload::QueryAt(r) => Ok(r),
197            ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
198            _ => Err(ClientError::UnexpectedResponse {
199                expected: "QueryAt".to_string(),
200                actual: format!("{:?}", response.payload),
201            }),
202        }
203    }
204
205    /// Reads events from a stream.
206    pub fn read_events(
207        &mut self,
208        stream_id: StreamId,
209        from_offset: Offset,
210        max_bytes: u64,
211    ) -> ClientResult<ReadEventsResponse> {
212        let response = self.send_request(RequestPayload::ReadEvents(ReadEventsRequest {
213            stream_id,
214            from_offset,
215            max_bytes,
216        }))?;
217
218        match response.payload {
219            ResponsePayload::ReadEvents(r) => Ok(r),
220            ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
221            _ => Err(ClientError::UnexpectedResponse {
222                expected: "ReadEvents".to_string(),
223                actual: format!("{:?}", response.payload),
224            }),
225        }
226    }
227
228    /// Syncs all data to disk.
229    pub fn sync(&mut self) -> ClientResult<()> {
230        let response = self.send_request(RequestPayload::Sync(SyncRequest {}))?;
231
232        match response.payload {
233            ResponsePayload::Sync(r) => {
234                if r.success {
235                    Ok(())
236                } else {
237                    Err(ClientError::server(ErrorCode::InternalError, "sync failed"))
238                }
239            }
240            ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
241            _ => Err(ClientError::UnexpectedResponse {
242                expected: "Sync".to_string(),
243                actual: format!("{:?}", response.payload),
244            }),
245        }
246    }
247
248    /// Returns the tenant ID for this client.
249    pub fn tenant_id(&self) -> TenantId {
250        self.tenant_id
251    }
252
253    /// Sends a request and waits for the response.
254    fn send_request(&mut self, payload: RequestPayload) -> ClientResult<Response> {
255        let request_id = RequestId::new(self.next_request_id);
256        self.next_request_id += 1;
257
258        let request = Request::new(request_id, self.tenant_id, payload);
259
260        // Encode and send the request
261        let frame = request.to_frame()?;
262        let mut write_buf = BytesMut::new();
263        frame.encode(&mut write_buf);
264        self.stream.write_all(&write_buf)?;
265        self.stream.flush()?;
266
267        // Read the response
268        let response = self.read_response()?;
269
270        // Verify request ID matches
271        if response.request_id.0 != request_id.0 {
272            return Err(ClientError::ResponseMismatch {
273                expected: request_id.0,
274                received: response.request_id.0,
275            });
276        }
277
278        Ok(response)
279    }
280
281    /// Reads a response from the server.
282    fn read_response(&mut self) -> ClientResult<Response> {
283        loop {
284            // Try to decode a frame from the buffer
285            if let Some(frame) = Frame::decode(&mut self.read_buf)? {
286                let response = Response::from_frame(&frame)?;
287                return Ok(response);
288            }
289
290            // Need more data - read from socket
291            let mut temp_buf = [0u8; 4096];
292            let n = self.stream.read(&mut temp_buf)?;
293            if n == 0 {
294                return Err(ClientError::Connection(std::io::Error::new(
295                    std::io::ErrorKind::UnexpectedEof,
296                    "server closed connection",
297                )));
298            }
299            self.read_buf.extend_from_slice(&temp_buf[..n]);
300
301            // Check for buffer overflow (simple DoS protection)
302            if self.read_buf.len() > self.config.buffer_size * 2 {
303                return Err(ClientError::Connection(std::io::Error::new(
304                    std::io::ErrorKind::InvalidData,
305                    "response too large",
306                )));
307            }
308        }
309    }
310}
311
312impl std::fmt::Debug for Client {
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        f.debug_struct("Client")
315            .field("tenant_id", &self.tenant_id)
316            .field("next_request_id", &self.next_request_id)
317            .finish_non_exhaustive()
318    }
319}