1use std::io::{Read, Write};
4use std::net::{TcpStream, ToSocketAddrs};
5use std::time::Duration;
6
7use bytes::BytesMut;
8use kimberlite_types::{DataClass, Offset, Placement, StreamId, TenantId};
9use kimberlite_wire::{
10 AppendEventsRequest, CreateStreamRequest, ErrorCode, Frame, HandshakeRequest, PROTOCOL_VERSION,
11 QueryAtRequest, QueryParam, QueryRequest, QueryResponse, ReadEventsRequest, ReadEventsResponse,
12 Request, RequestId, RequestPayload, Response, ResponsePayload, SyncRequest,
13};
14
15use crate::error::{ClientError, ClientResult};
16
/// Connection settings used when opening a client connection.
///
/// `Debug` is implemented by hand (not derived) so that `auth_token` is never
/// printed in logs, panic messages, or other diagnostic output.
#[derive(Clone)]
pub struct ClientConfig {
    /// Socket read timeout applied at connect time; `None` means reads block
    /// indefinitely. Note that `TcpStream::set_read_timeout` rejects a zero
    /// `Duration`, so `Some(Duration::ZERO)` makes connecting fail.
    pub read_timeout: Option<Duration>,
    /// Socket write timeout applied at connect time; `None` means writes block
    /// indefinitely. A zero `Duration` is rejected in the same way.
    pub write_timeout: Option<Duration>,
    /// Initial capacity, in bytes, of the receive buffer. Responses whose
    /// buffered, not-yet-decoded bytes exceed twice this value are rejected
    /// as "response too large".
    pub buffer_size: usize,
    /// Optional token sent to the server in the handshake request.
    pub auth_token: Option<String>,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            .field("read_timeout", &self.read_timeout)
            .field("write_timeout", &self.write_timeout)
            .field("buffer_size", &self.buffer_size)
            // Only reveal whether a token is present, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
29
30impl Default for ClientConfig {
31 fn default() -> Self {
32 Self {
33 read_timeout: Some(Duration::from_secs(30)),
34 write_timeout: Some(Duration::from_secs(30)),
35 buffer_size: 64 * 1024,
36 auth_token: None,
37 }
38 }
39}
40
41pub struct Client {
61 stream: TcpStream,
62 tenant_id: TenantId,
63 next_request_id: u64,
64 read_buf: BytesMut,
65 config: ClientConfig,
66}
67
68impl Client {
69 pub fn connect(
71 addr: impl ToSocketAddrs,
72 tenant_id: TenantId,
73 config: ClientConfig,
74 ) -> ClientResult<Self> {
75 let stream = TcpStream::connect(addr)?;
76 stream.set_read_timeout(config.read_timeout)?;
77 stream.set_write_timeout(config.write_timeout)?;
78
79 let mut client = Self {
80 stream,
81 tenant_id,
82 next_request_id: 1,
83 read_buf: BytesMut::with_capacity(config.buffer_size),
84 config,
85 };
86
87 client.handshake()?;
89
90 Ok(client)
91 }
92
93 fn handshake(&mut self) -> ClientResult<()> {
95 let response = self.send_request(RequestPayload::Handshake(HandshakeRequest {
96 client_version: PROTOCOL_VERSION,
97 auth_token: self.config.auth_token.clone(),
98 }))?;
99
100 match response.payload {
101 ResponsePayload::Handshake(h) => {
102 if h.server_version != PROTOCOL_VERSION {
103 return Err(ClientError::HandshakeFailed(format!(
104 "protocol version mismatch: client {}, server {}",
105 PROTOCOL_VERSION, h.server_version
106 )));
107 }
108 Ok(())
109 }
110 ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
111 _ => Err(ClientError::UnexpectedResponse {
112 expected: "Handshake".to_string(),
113 actual: format!("{:?}", response.payload),
114 }),
115 }
116 }
117
118 pub fn create_stream(&mut self, name: &str, data_class: DataClass) -> ClientResult<StreamId> {
120 self.create_stream_with_placement(name, data_class, Placement::Global)
121 }
122
123 pub fn create_stream_with_placement(
125 &mut self,
126 name: &str,
127 data_class: DataClass,
128 placement: Placement,
129 ) -> ClientResult<StreamId> {
130 let response = self.send_request(RequestPayload::CreateStream(CreateStreamRequest {
131 name: name.to_string(),
132 data_class,
133 placement,
134 }))?;
135
136 match response.payload {
137 ResponsePayload::CreateStream(r) => Ok(r.stream_id),
138 ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
139 _ => Err(ClientError::UnexpectedResponse {
140 expected: "CreateStream".to_string(),
141 actual: format!("{:?}", response.payload),
142 }),
143 }
144 }
145
146 pub fn append(&mut self, stream_id: StreamId, events: Vec<Vec<u8>>) -> ClientResult<Offset> {
150 let response = self.send_request(RequestPayload::AppendEvents(AppendEventsRequest {
151 stream_id,
152 events,
153 }))?;
154
155 match response.payload {
156 ResponsePayload::AppendEvents(r) => Ok(r.first_offset),
157 ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
158 _ => Err(ClientError::UnexpectedResponse {
159 expected: "AppendEvents".to_string(),
160 actual: format!("{:?}", response.payload),
161 }),
162 }
163 }
164
165 pub fn query(&mut self, sql: &str, params: &[QueryParam]) -> ClientResult<QueryResponse> {
167 let response = self.send_request(RequestPayload::Query(QueryRequest {
168 sql: sql.to_string(),
169 params: params.to_vec(),
170 }))?;
171
172 match response.payload {
173 ResponsePayload::Query(r) => Ok(r),
174 ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
175 _ => Err(ClientError::UnexpectedResponse {
176 expected: "Query".to_string(),
177 actual: format!("{:?}", response.payload),
178 }),
179 }
180 }
181
182 pub fn query_at(
184 &mut self,
185 sql: &str,
186 params: &[QueryParam],
187 position: Offset,
188 ) -> ClientResult<QueryResponse> {
189 let response = self.send_request(RequestPayload::QueryAt(QueryAtRequest {
190 sql: sql.to_string(),
191 params: params.to_vec(),
192 position,
193 }))?;
194
195 match response.payload {
196 ResponsePayload::QueryAt(r) => Ok(r),
197 ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
198 _ => Err(ClientError::UnexpectedResponse {
199 expected: "QueryAt".to_string(),
200 actual: format!("{:?}", response.payload),
201 }),
202 }
203 }
204
205 pub fn read_events(
207 &mut self,
208 stream_id: StreamId,
209 from_offset: Offset,
210 max_bytes: u64,
211 ) -> ClientResult<ReadEventsResponse> {
212 let response = self.send_request(RequestPayload::ReadEvents(ReadEventsRequest {
213 stream_id,
214 from_offset,
215 max_bytes,
216 }))?;
217
218 match response.payload {
219 ResponsePayload::ReadEvents(r) => Ok(r),
220 ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
221 _ => Err(ClientError::UnexpectedResponse {
222 expected: "ReadEvents".to_string(),
223 actual: format!("{:?}", response.payload),
224 }),
225 }
226 }
227
228 pub fn sync(&mut self) -> ClientResult<()> {
230 let response = self.send_request(RequestPayload::Sync(SyncRequest {}))?;
231
232 match response.payload {
233 ResponsePayload::Sync(r) => {
234 if r.success {
235 Ok(())
236 } else {
237 Err(ClientError::server(ErrorCode::InternalError, "sync failed"))
238 }
239 }
240 ResponsePayload::Error(e) => Err(ClientError::server(e.code, e.message)),
241 _ => Err(ClientError::UnexpectedResponse {
242 expected: "Sync".to_string(),
243 actual: format!("{:?}", response.payload),
244 }),
245 }
246 }
247
248 pub fn tenant_id(&self) -> TenantId {
250 self.tenant_id
251 }
252
253 fn send_request(&mut self, payload: RequestPayload) -> ClientResult<Response> {
255 let request_id = RequestId::new(self.next_request_id);
256 self.next_request_id += 1;
257
258 let request = Request::new(request_id, self.tenant_id, payload);
259
260 let frame = request.to_frame()?;
262 let mut write_buf = BytesMut::new();
263 frame.encode(&mut write_buf);
264 self.stream.write_all(&write_buf)?;
265 self.stream.flush()?;
266
267 let response = self.read_response()?;
269
270 if response.request_id.0 != request_id.0 {
272 return Err(ClientError::ResponseMismatch {
273 expected: request_id.0,
274 received: response.request_id.0,
275 });
276 }
277
278 Ok(response)
279 }
280
281 fn read_response(&mut self) -> ClientResult<Response> {
283 loop {
284 if let Some(frame) = Frame::decode(&mut self.read_buf)? {
286 let response = Response::from_frame(&frame)?;
287 return Ok(response);
288 }
289
290 let mut temp_buf = [0u8; 4096];
292 let n = self.stream.read(&mut temp_buf)?;
293 if n == 0 {
294 return Err(ClientError::Connection(std::io::Error::new(
295 std::io::ErrorKind::UnexpectedEof,
296 "server closed connection",
297 )));
298 }
299 self.read_buf.extend_from_slice(&temp_buf[..n]);
300
301 if self.read_buf.len() > self.config.buffer_size * 2 {
303 return Err(ClientError::Connection(std::io::Error::new(
304 std::io::ErrorKind::InvalidData,
305 "response too large",
306 )));
307 }
308 }
309 }
310}
311
312impl std::fmt::Debug for Client {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 f.debug_struct("Client")
315 .field("tenant_id", &self.tenant_id)
316 .field("next_request_id", &self.next_request_id)
317 .finish_non_exhaustive()
318 }
319}