1use crate::error::{ClientError, Result};
2use crate::tls::TlsClientConfig;
3use bytes::Bytes;
4use lnc_network::{
5 ControlCommand, Frame, FrameType, LWP_HEADER_SIZE, TlsConnector, encode_frame, parse_frame,
6};
7use std::net::SocketAddr;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::task::{Context, Poll};
11use std::time::Duration;
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
13use tokio::net::TcpStream;
14use tokio::net::lookup_host;
15use tracing::{debug, trace, warn};
16
17#[allow(clippy::large_enum_variant)]
19pub enum ClientStream {
20 Tcp(TcpStream),
22 Tls(tokio_rustls::client::TlsStream<TcpStream>),
24}
25
26impl AsyncRead for ClientStream {
27 fn poll_read(
28 self: Pin<&mut Self>,
29 cx: &mut Context<'_>,
30 buf: &mut ReadBuf<'_>,
31 ) -> Poll<std::io::Result<()>> {
32 match self.get_mut() {
33 ClientStream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
34 ClientStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
35 }
36 }
37}
38
39impl AsyncWrite for ClientStream {
40 fn poll_write(
41 self: Pin<&mut Self>,
42 cx: &mut Context<'_>,
43 buf: &[u8],
44 ) -> Poll<std::io::Result<usize>> {
45 match self.get_mut() {
46 ClientStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
47 ClientStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
48 }
49 }
50
51 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
52 match self.get_mut() {
53 ClientStream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
54 ClientStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
55 }
56 }
57
58 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
59 match self.get_mut() {
60 ClientStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
61 ClientStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
62 }
63 }
64}
65
66fn extract_error_message(frame: &Frame) -> String {
68 frame
69 .payload
70 .as_ref()
71 .map(|p| String::from_utf8_lossy(p).to_string())
72 .unwrap_or_else(|| "Unknown error".to_string())
73}
74
75#[allow(dead_code)] fn expect_frame_type(frame: Frame, expected: ControlCommand, expected_name: &str) -> Result<Frame> {
78 match frame.frame_type {
79 FrameType::Control(cmd) if cmd == expected => Ok(frame),
80 FrameType::Control(ControlCommand::ErrorResponse) => {
81 Err(ClientError::ServerError(extract_error_message(&frame)))
82 },
83 other => Err(ClientError::InvalidResponse(format!(
84 "Expected {}, got {:?}",
85 expected_name, other
86 ))),
87 }
88}
89
90fn expect_success_response(frame: Frame) -> Result<()> {
92 match frame.frame_type {
93 FrameType::Control(ControlCommand::TopicResponse) => Ok(()),
94 FrameType::Control(ControlCommand::ErrorResponse) => {
95 Err(ClientError::ServerError(extract_error_message(&frame)))
96 },
97 other => Err(ClientError::InvalidResponse(format!(
98 "Expected TopicResponse, got {:?}",
99 other
100 ))),
101 }
102}
103
/// Optional authentication settings that can be passed alongside an ingest request.
///
/// NOTE(review): the ingest path in this file takes these settings as `_auth_config`
/// and does not read them yet. Confirm before relying on them.
#[derive(Debug, Clone, Default)]
pub struct AuthConfig {
    /// Whether mutual TLS is requested.
    pub mtls_enabled: bool,
    /// Filesystem path of the client certificate, if any.
    pub client_cert_path: Option<String>,
    /// Filesystem path of the client private key, if any.
    pub client_key_path: Option<String>,
}
114
/// Retention limits attached to a topic, as reported by the server.
#[derive(Debug, Clone, Default)]
pub struct RetentionInfo {
    /// Maximum age of retained data, in seconds.
    pub max_age_secs: u64,
    /// Maximum retained size, in bytes.
    pub max_bytes: u64,
}
123
124#[derive(Debug, Clone)]
126pub struct TopicInfo {
127 pub id: u32,
129 pub name: String,
131 pub created_at: u64,
133 pub retention: Option<RetentionInfo>,
135}
136
137#[derive(Debug, Clone)]
139pub struct FetchResult {
140 pub data: Bytes,
142 pub next_offset: u64,
144 pub bytes_returned: u32,
146 pub record_count: u32,
148}
149
/// Server acknowledgement of a subscription request.
#[derive(Debug, Clone)]
pub struct SubscribeResult {
    /// Consumer id echoed back by the server.
    pub consumer_id: u64,
    /// Offset the server reports for the start of the subscription.
    pub start_offset: u64,
}
158
/// Server acknowledgement of an offset commit.
#[derive(Debug, Clone)]
pub struct CommitResult {
    /// Consumer id echoed back by the server.
    pub consumer_id: u64,
    /// Offset the server reports as committed.
    pub committed_offset: u64,
}
167
/// Cluster state as reported by the connected node in a `ClusterStatusResponse`.
#[derive(Debug, Clone)]
pub struct ClusterStatus {
    /// Id of the node that answered.
    pub node_id: u16,
    /// Whether the answering node reports itself as leader.
    pub is_leader: bool,
    /// Leader id, if the node reported one.
    pub leader_id: Option<u16>,
    /// Term number reported by the node.
    pub current_term: u64,
    /// Number of nodes the node reports in the cluster.
    pub node_count: usize,
    /// Number of nodes the node reports as healthy.
    pub healthy_nodes: usize,
    /// Whether the node reports that a quorum is available.
    pub quorum_available: bool,
    /// Per-peer state strings keyed by peer node id.
    pub peer_states: std::collections::HashMap<u16, String>,
}
180
181#[derive(Debug, Clone)]
183pub struct ClientConfig {
184 pub addr: String,
186 pub connect_timeout: Duration,
188 pub read_timeout: Duration,
190 pub write_timeout: Duration,
192 pub keepalive_interval: Duration,
194 pub tls: Option<TlsClientConfig>,
196}
197
198impl Default for ClientConfig {
199 fn default() -> Self {
200 Self {
201 addr: "127.0.0.1:1992".to_string(),
202 connect_timeout: Duration::from_secs(10),
203 read_timeout: Duration::from_secs(30),
204 write_timeout: Duration::from_secs(10),
205 keepalive_interval: Duration::from_secs(10),
206 tls: None,
207 }
208 }
209}
210
211impl ClientConfig {
212 pub fn new(addr: impl Into<String>) -> Self {
218 Self {
219 addr: addr.into(),
220 ..Default::default()
221 }
222 }
223
224 pub fn with_tls(mut self, tls_config: TlsClientConfig) -> Self {
226 self.tls = Some(tls_config);
227 self
228 }
229
230 pub fn is_tls_enabled(&self) -> bool {
232 self.tls.is_some()
233 }
234}
235
236pub struct LanceClient {
240 stream: ClientStream,
241 config: ClientConfig,
242 batch_id: AtomicU64,
243 read_buffer: Vec<u8>,
244 read_offset: usize,
245}
246
247impl LanceClient {
248 async fn resolve_address(addr: &str) -> Result<SocketAddr> {
252 if let Ok(socket_addr) = addr.parse::<SocketAddr>() {
254 return Ok(socket_addr);
255 }
256
257 let mut addrs = lookup_host(addr).await.map_err(|e| {
259 ClientError::ProtocolError(format!("DNS resolution failed for '{}': {}", addr, e))
260 })?;
261
262 addrs
263 .next()
264 .ok_or_else(|| ClientError::ProtocolError(format!("No addresses found for '{}'", addr)))
265 }
266
267 pub async fn connect(config: ClientConfig) -> Result<Self> {
272 if let Some(ref tls_config) = config.tls {
274 return Self::connect_tls(config.clone(), tls_config.clone()).await;
275 }
276
277 debug!(addr = %config.addr, "Connecting to LANCE server");
278
279 let socket_addr = Self::resolve_address(&config.addr).await?;
281 debug!(resolved_addr = %socket_addr, "Resolved server address");
282
283 let stream = tokio::time::timeout(config.connect_timeout, TcpStream::connect(socket_addr))
284 .await
285 .map_err(|_| ClientError::Timeout)?
286 .map_err(ClientError::ConnectionFailed)?;
287
288 stream.set_nodelay(true)?;
289
290 debug!(addr = %config.addr, "Connected to LANCE server");
291
292 Ok(Self {
293 stream: ClientStream::Tcp(stream),
294 config,
295 batch_id: AtomicU64::new(0),
296 read_buffer: vec![0u8; 64 * 1024],
297 read_offset: 0,
298 })
299 }
300
301 pub async fn connect_tls(config: ClientConfig, tls_config: TlsClientConfig) -> Result<Self> {
318 debug!(addr = %config.addr, "Connecting to LANCE server with TLS");
319
320 let socket_addr = Self::resolve_address(&config.addr).await?;
322 debug!(resolved_addr = %socket_addr, "Resolved server address");
323
324 let tcp_stream =
326 tokio::time::timeout(config.connect_timeout, TcpStream::connect(socket_addr))
327 .await
328 .map_err(|_| ClientError::Timeout)?
329 .map_err(ClientError::ConnectionFailed)?;
330
331 tcp_stream.set_nodelay(true)?;
332
333 let network_config = tls_config.to_network_config();
335 let connector =
336 TlsConnector::new(network_config).map_err(|e| ClientError::TlsError(e.to_string()))?;
337
338 let server_name = tls_config.server_name.unwrap_or_else(|| {
340 config
342 .addr
343 .rsplit_once(':')
344 .map(|(host, _)| host.to_string())
345 .unwrap_or_else(|| socket_addr.ip().to_string())
346 });
347
348 let tls_stream = connector
350 .connect(&server_name, tcp_stream)
351 .await
352 .map_err(|e| ClientError::TlsError(e.to_string()))?;
353
354 debug!(addr = %config.addr, "TLS connection established");
355
356 Ok(Self {
357 stream: ClientStream::Tls(tls_stream),
358 config,
359 batch_id: AtomicU64::new(0),
360 read_buffer: vec![0u8; 64 * 1024],
361 read_offset: 0,
362 })
363 }
364
365 pub async fn connect_to(addr: &str) -> Result<Self> {
370 Self::connect(ClientConfig::new(addr)).await
371 }
372
373 pub async fn connect_tls_to(addr: &str, tls_config: TlsClientConfig) -> Result<Self> {
378 Self::connect_tls(ClientConfig::new(addr), tls_config).await
379 }
380
381 fn next_batch_id(&self) -> u64 {
382 self.batch_id.fetch_add(1, Ordering::SeqCst) + 1
383 }
384
385 pub async fn send_ingest(&mut self, payload: Bytes, record_count: u32) -> Result<u64> {
387 self.send_ingest_to_topic(0, payload, record_count, None)
388 .await
389 }
390
391 pub async fn send_ingest_to_topic(
393 &mut self,
394 topic_id: u32,
395 payload: Bytes,
396 record_count: u32,
397 _auth_config: Option<&AuthConfig>,
398 ) -> Result<u64> {
399 let batch_id = self.next_batch_id();
400 let timestamp_ns = std::time::SystemTime::now()
401 .duration_since(std::time::UNIX_EPOCH)
402 .map(|d| d.as_nanos() as u64)
403 .unwrap_or(0);
404
405 let frame =
406 Frame::new_ingest_with_topic(batch_id, timestamp_ns, record_count, payload, topic_id);
407 let frame_bytes = encode_frame(&frame);
408
409 trace!(
410 batch_id,
411 topic_id,
412 payload_len = frame.payload_length(),
413 "Sending ingest frame"
414 );
415
416 tokio::time::timeout(
417 self.config.write_timeout,
418 self.stream.write_all(&frame_bytes),
419 )
420 .await
421 .map_err(|_| ClientError::Timeout)??;
422
423 Ok(batch_id)
424 }
425
426 pub async fn send_ingest_sync(&mut self, payload: Bytes, record_count: u32) -> Result<u64> {
428 self.send_ingest_to_topic_sync(0, payload, record_count, None)
429 .await
430 }
431
432 pub async fn send_ingest_to_topic_sync(
434 &mut self,
435 topic_id: u32,
436 payload: Bytes,
437 record_count: u32,
438 auth_config: Option<&AuthConfig>,
439 ) -> Result<u64> {
440 let batch_id = self
441 .send_ingest_to_topic(topic_id, payload, record_count, auth_config)
442 .await?;
443 self.wait_for_ack(batch_id).await
444 }
445
446 async fn wait_for_ack(&mut self, expected_batch_id: u64) -> Result<u64> {
447 let frame = self.recv_frame().await?;
448
449 match frame.frame_type {
450 FrameType::Ack => {
451 let acked_id = frame.batch_id();
452 if acked_id != expected_batch_id {
453 return Err(ClientError::InvalidResponse(format!(
454 "Ack batch_id mismatch: sent {}, received {}",
455 expected_batch_id, acked_id
456 )));
457 }
458 trace!(batch_id = acked_id, "Received ack");
459 Ok(acked_id)
460 },
461 FrameType::Control(ControlCommand::ErrorResponse) => {
462 let error_msg = frame
463 .payload
464 .map(|p| String::from_utf8_lossy(&p).to_string())
465 .unwrap_or_else(|| "Unknown error".to_string());
466 Err(ClientError::ServerError(error_msg))
467 },
468 FrameType::Backpressure => {
469 warn!("Server signaled backpressure");
470 Err(ClientError::ServerBackpressure)
471 },
472 other => Err(ClientError::InvalidResponse(format!(
473 "Expected Ack, got {:?}",
474 other
475 ))),
476 }
477 }
478
479 pub async fn recv_ack(&mut self) -> Result<u64> {
481 let frame = self.recv_frame().await?;
482
483 match frame.frame_type {
484 FrameType::Ack => {
485 trace!(batch_id = frame.batch_id(), "Received ack");
486 Ok(frame.batch_id())
487 },
488 FrameType::Backpressure => {
489 warn!("Server signaled backpressure");
490 Err(ClientError::ServerBackpressure)
491 },
492 other => Err(ClientError::InvalidResponse(format!(
493 "Expected Ack, got {:?}",
494 other
495 ))),
496 }
497 }
498
499 pub async fn send_keepalive(&mut self) -> Result<()> {
501 let frame = Frame::new_keepalive();
502 let frame_bytes = encode_frame(&frame);
503
504 trace!("Sending keepalive");
505
506 tokio::time::timeout(
507 self.config.write_timeout,
508 self.stream.write_all(&frame_bytes),
509 )
510 .await
511 .map_err(|_| ClientError::Timeout)??;
512
513 Ok(())
514 }
515
516 pub async fn recv_keepalive(&mut self) -> Result<()> {
518 let frame = self.recv_frame().await?;
519
520 match frame.frame_type {
521 FrameType::Keepalive => {
522 trace!("Received keepalive response");
523 Ok(())
524 },
525 other => Err(ClientError::InvalidResponse(format!(
526 "Expected Keepalive, got {:?}",
527 other
528 ))),
529 }
530 }
531
532 pub async fn ping(&mut self) -> Result<Duration> {
534 let start = std::time::Instant::now();
535 self.send_keepalive().await?;
536 self.recv_keepalive().await?;
537 Ok(start.elapsed())
538 }
539
540 pub async fn create_topic(&mut self, name: &str) -> Result<TopicInfo> {
542 let frame = Frame::new_create_topic(name);
543 let frame_bytes = encode_frame(&frame);
544
545 trace!(topic_name = %name, "Creating topic");
546
547 tokio::time::timeout(
548 self.config.write_timeout,
549 self.stream.write_all(&frame_bytes),
550 )
551 .await
552 .map_err(|_| ClientError::Timeout)??;
553
554 let response = self.recv_frame().await?;
555 self.parse_topic_response(response)
556 }
557
558 pub async fn list_topics(&mut self) -> Result<Vec<TopicInfo>> {
560 let frame = Frame::new_list_topics();
561 let frame_bytes = encode_frame(&frame);
562
563 trace!("Listing topics");
564
565 tokio::time::timeout(
566 self.config.write_timeout,
567 self.stream.write_all(&frame_bytes),
568 )
569 .await
570 .map_err(|_| ClientError::Timeout)??;
571
572 let response = self.recv_frame().await?;
573 self.parse_topic_list_response(response)
574 }
575
576 pub async fn get_topic(&mut self, topic_id: u32) -> Result<TopicInfo> {
578 let frame = Frame::new_get_topic(topic_id);
579 let frame_bytes = encode_frame(&frame);
580
581 trace!(topic_id, "Getting topic");
582
583 tokio::time::timeout(
584 self.config.write_timeout,
585 self.stream.write_all(&frame_bytes),
586 )
587 .await
588 .map_err(|_| ClientError::Timeout)??;
589
590 let response = self.recv_frame().await?;
591 self.parse_topic_response(response)
592 }
593
594 pub async fn delete_topic(&mut self, topic_id: u32) -> Result<()> {
596 let frame = Frame::new_delete_topic(topic_id);
597 let frame_bytes = encode_frame(&frame);
598
599 trace!(topic_id, "Deleting topic");
600
601 tokio::time::timeout(
602 self.config.write_timeout,
603 self.stream.write_all(&frame_bytes),
604 )
605 .await
606 .map_err(|_| ClientError::Timeout)??;
607
608 let response = self.recv_frame().await?;
609 self.parse_delete_response(response)
610 }
611
612 pub async fn set_retention(
619 &mut self,
620 topic_id: u32,
621 max_age_secs: u64,
622 max_bytes: u64,
623 ) -> Result<()> {
624 let frame = Frame::new_set_retention(topic_id, max_age_secs, max_bytes);
625 let frame_bytes = encode_frame(&frame);
626
627 trace!(
628 topic_id,
629 max_age_secs, max_bytes, "Setting retention policy"
630 );
631
632 tokio::time::timeout(
633 self.config.write_timeout,
634 self.stream.write_all(&frame_bytes),
635 )
636 .await
637 .map_err(|_| ClientError::Timeout)??;
638
639 let response = self.recv_frame().await?;
640 self.parse_retention_response(response)
641 }
642
643 pub async fn create_topic_with_retention(
650 &mut self,
651 name: &str,
652 max_age_secs: u64,
653 max_bytes: u64,
654 ) -> Result<TopicInfo> {
655 let frame = Frame::new_create_topic_with_retention(name, max_age_secs, max_bytes);
656 let frame_bytes = encode_frame(&frame);
657
658 trace!(
659 name,
660 max_age_secs, max_bytes, "Creating topic with retention"
661 );
662
663 tokio::time::timeout(
664 self.config.write_timeout,
665 self.stream.write_all(&frame_bytes),
666 )
667 .await
668 .map_err(|_| ClientError::Timeout)??;
669
670 let response = self.recv_frame().await?;
671 self.parse_topic_response(response)
672 }
673
674 pub async fn get_cluster_status(&mut self) -> Result<ClusterStatus> {
676 let frame = Frame::new_get_cluster_status();
677 let frame_bytes = encode_frame(&frame);
678
679 tokio::time::timeout(
680 self.config.write_timeout,
681 self.stream.write_all(&frame_bytes),
682 )
683 .await
684 .map_err(|_| ClientError::Timeout)??;
685
686 let response = self.recv_frame().await?;
687 self.parse_cluster_status_response(response)
688 }
689
690 fn parse_cluster_status_response(&self, frame: Frame) -> Result<ClusterStatus> {
691 match frame.frame_type {
692 FrameType::Control(ControlCommand::ClusterStatusResponse) => {
693 let payload = frame.payload.ok_or_else(|| {
694 ClientError::InvalidResponse("Empty cluster status response".to_string())
695 })?;
696 let json: serde_json::Value = serde_json::from_slice(&payload)
697 .map_err(|e| ClientError::ProtocolError(format!("Invalid JSON: {}", e)))?;
698
699 let peer_states: std::collections::HashMap<u16, String> = json["peer_states"]
700 .as_object()
701 .map(|obj| {
702 obj.iter()
703 .filter_map(|(k, v)| {
704 k.parse::<u16>()
705 .ok()
706 .map(|id| (id, v.as_str().unwrap_or("unknown").to_string()))
707 })
708 .collect()
709 })
710 .unwrap_or_default();
711
712 Ok(ClusterStatus {
713 node_id: json["node_id"].as_u64().unwrap_or(0) as u16,
714 is_leader: json["is_leader"].as_bool().unwrap_or(false),
715 leader_id: json["leader_id"].as_u64().map(|id| id as u16),
716 current_term: json["current_term"].as_u64().unwrap_or(0),
717 node_count: json["node_count"].as_u64().unwrap_or(1) as usize,
718 healthy_nodes: json["healthy_nodes"].as_u64().unwrap_or(1) as usize,
719 quorum_available: json["quorum_available"].as_bool().unwrap_or(true),
720 peer_states,
721 })
722 },
723 FrameType::Control(ControlCommand::ErrorResponse) => {
724 let error_msg = frame
725 .payload
726 .map(|p| String::from_utf8_lossy(&p).to_string())
727 .unwrap_or_else(|| "Unknown error".to_string());
728 Err(ClientError::ServerError(error_msg))
729 },
730 other => Err(ClientError::InvalidResponse(format!(
731 "Expected ClusterStatusResponse, got {:?}",
732 other
733 ))),
734 }
735 }
736
737 pub async fn fetch(
740 &mut self,
741 topic_id: u32,
742 start_offset: u64,
743 max_bytes: u32,
744 ) -> Result<FetchResult> {
745 let frame = Frame::new_fetch(topic_id, start_offset, max_bytes);
746 let frame_bytes = encode_frame(&frame);
747
748 trace!(topic_id, start_offset, max_bytes, "Fetching data");
749
750 tokio::time::timeout(
751 self.config.write_timeout,
752 self.stream.write_all(&frame_bytes),
753 )
754 .await
755 .map_err(|_| ClientError::Timeout)??;
756
757 let response = self.recv_frame().await?;
758 self.parse_fetch_response(response)
759 }
760
761 pub async fn subscribe(
764 &mut self,
765 topic_id: u32,
766 start_offset: u64,
767 max_batch_bytes: u32,
768 consumer_id: u64,
769 ) -> Result<SubscribeResult> {
770 let frame = Frame::new_subscribe(topic_id, start_offset, max_batch_bytes, consumer_id);
771 let frame_bytes = encode_frame(&frame);
772
773 trace!(topic_id, start_offset, consumer_id, "Subscribing to topic");
774
775 tokio::time::timeout(
776 self.config.write_timeout,
777 self.stream.write_all(&frame_bytes),
778 )
779 .await
780 .map_err(|_| ClientError::Timeout)??;
781
782 let response = self.recv_frame().await?;
783 self.parse_subscribe_response(response)
784 }
785
786 pub async fn unsubscribe(&mut self, topic_id: u32, consumer_id: u64) -> Result<()> {
788 let frame = Frame::new_unsubscribe(topic_id, consumer_id);
789 let frame_bytes = encode_frame(&frame);
790
791 trace!(topic_id, consumer_id, "Unsubscribing from topic");
792
793 tokio::time::timeout(
794 self.config.write_timeout,
795 self.stream.write_all(&frame_bytes),
796 )
797 .await
798 .map_err(|_| ClientError::Timeout)??;
799
800 let response = self.recv_frame().await?;
802 match response.frame_type {
803 FrameType::Ack => Ok(()),
804 FrameType::Control(ControlCommand::ErrorResponse) => {
805 let error_msg = response
806 .payload
807 .map(|p| String::from_utf8_lossy(&p).to_string())
808 .unwrap_or_else(|| "Unknown error".to_string());
809 Err(ClientError::ServerError(error_msg))
810 },
811 other => Err(ClientError::InvalidResponse(format!(
812 "Expected Ack, got {:?}",
813 other
814 ))),
815 }
816 }
817
818 pub async fn commit_offset(
820 &mut self,
821 topic_id: u32,
822 consumer_id: u64,
823 offset: u64,
824 ) -> Result<CommitResult> {
825 let frame = Frame::new_commit_offset(topic_id, consumer_id, offset);
826 let frame_bytes = encode_frame(&frame);
827
828 trace!(topic_id, consumer_id, offset, "Committing offset");
829
830 tokio::time::timeout(
831 self.config.write_timeout,
832 self.stream.write_all(&frame_bytes),
833 )
834 .await
835 .map_err(|_| ClientError::Timeout)??;
836
837 let response = self.recv_frame().await?;
838 self.parse_commit_response(response)
839 }
840
841 fn parse_subscribe_response(&self, frame: Frame) -> Result<SubscribeResult> {
842 match frame.frame_type {
843 FrameType::Control(ControlCommand::SubscribeAck) => {
844 let payload = frame.payload.ok_or_else(|| {
845 ClientError::InvalidResponse("Empty subscribe response".to_string())
846 })?;
847
848 if payload.len() < 16 {
849 return Err(ClientError::ProtocolError(
850 "Subscribe response too small".to_string(),
851 ));
852 }
853
854 let consumer_id = u64::from_le_bytes([
855 payload[0], payload[1], payload[2], payload[3], payload[4], payload[5],
856 payload[6], payload[7],
857 ]);
858 let start_offset = u64::from_le_bytes([
859 payload[8],
860 payload[9],
861 payload[10],
862 payload[11],
863 payload[12],
864 payload[13],
865 payload[14],
866 payload[15],
867 ]);
868
869 Ok(SubscribeResult {
870 consumer_id,
871 start_offset,
872 })
873 },
874 FrameType::Control(ControlCommand::ErrorResponse) => {
875 let error_msg = frame
876 .payload
877 .map(|p| String::from_utf8_lossy(&p).to_string())
878 .unwrap_or_else(|| "Unknown error".to_string());
879 Err(ClientError::ServerError(error_msg))
880 },
881 other => Err(ClientError::InvalidResponse(format!(
882 "Expected SubscribeAck, got {:?}",
883 other
884 ))),
885 }
886 }
887
888 fn parse_commit_response(&self, frame: Frame) -> Result<CommitResult> {
889 match frame.frame_type {
890 FrameType::Control(ControlCommand::CommitAck) => {
891 let payload = frame.payload.ok_or_else(|| {
892 ClientError::InvalidResponse("Empty commit response".to_string())
893 })?;
894
895 if payload.len() < 16 {
896 return Err(ClientError::ProtocolError(
897 "Commit response too small".to_string(),
898 ));
899 }
900
901 let consumer_id = u64::from_le_bytes([
902 payload[0], payload[1], payload[2], payload[3], payload[4], payload[5],
903 payload[6], payload[7],
904 ]);
905 let committed_offset = u64::from_le_bytes([
906 payload[8],
907 payload[9],
908 payload[10],
909 payload[11],
910 payload[12],
911 payload[13],
912 payload[14],
913 payload[15],
914 ]);
915
916 Ok(CommitResult {
917 consumer_id,
918 committed_offset,
919 })
920 },
921 FrameType::Control(ControlCommand::ErrorResponse) => {
922 let error_msg = frame
923 .payload
924 .map(|p| String::from_utf8_lossy(&p).to_string())
925 .unwrap_or_else(|| "Unknown error".to_string());
926 Err(ClientError::ServerError(error_msg))
927 },
928 other => Err(ClientError::InvalidResponse(format!(
929 "Expected CommitAck, got {:?}",
930 other
931 ))),
932 }
933 }
934
935 fn parse_fetch_response(&self, frame: Frame) -> Result<FetchResult> {
936 match frame.frame_type {
937 FrameType::Control(ControlCommand::FetchResponse) => {
938 let payload = frame.payload.ok_or_else(|| {
939 ClientError::InvalidResponse("Empty fetch response".to_string())
940 })?;
941
942 if payload.len() < 16 {
943 return Err(ClientError::ProtocolError(
944 "Fetch response too small".to_string(),
945 ));
946 }
947
948 let next_offset = u64::from_le_bytes([
949 payload[0], payload[1], payload[2], payload[3], payload[4], payload[5],
950 payload[6], payload[7],
951 ]);
952 let bytes_returned =
953 u32::from_le_bytes([payload[8], payload[9], payload[10], payload[11]]);
954 let record_count =
955 u32::from_le_bytes([payload[12], payload[13], payload[14], payload[15]]);
956 let data = payload.slice(16..);
957
958 Ok(FetchResult {
959 data,
960 next_offset,
961 bytes_returned,
962 record_count,
963 })
964 },
965 FrameType::Control(ControlCommand::ErrorResponse) => {
966 let error_msg = frame
967 .payload
968 .map(|p| String::from_utf8_lossy(&p).to_string())
969 .unwrap_or_else(|| "Unknown error".to_string());
970 Err(ClientError::ServerError(error_msg))
971 },
972 other => Err(ClientError::InvalidResponse(format!(
973 "Expected FetchResponse, got {:?}",
974 other
975 ))),
976 }
977 }
978
979 fn parse_delete_response(&self, frame: Frame) -> Result<()> {
980 expect_success_response(frame)
981 }
982
983 fn parse_retention_response(&self, frame: Frame) -> Result<()> {
984 expect_success_response(frame)
985 }
986
987 fn parse_topic_response(&self, frame: Frame) -> Result<TopicInfo> {
988 match frame.frame_type {
989 FrameType::Control(ControlCommand::TopicResponse) => {
990 let payload = frame.payload.ok_or_else(|| {
991 ClientError::InvalidResponse("Empty topic response".to_string())
992 })?;
993 let json: serde_json::Value = serde_json::from_slice(&payload)
994 .map_err(|e| ClientError::ProtocolError(format!("Invalid JSON: {}", e)))?;
995
996 let retention = if json.get("retention").is_some() {
997 Some(RetentionInfo {
998 max_age_secs: json["retention"]["max_age_secs"].as_u64().unwrap_or(0),
999 max_bytes: json["retention"]["max_bytes"].as_u64().unwrap_or(0),
1000 })
1001 } else {
1002 None
1003 };
1004
1005 Ok(TopicInfo {
1006 id: json["id"].as_u64().unwrap_or(0) as u32,
1007 name: json["name"].as_str().unwrap_or("").to_string(),
1008 created_at: json["created_at"].as_u64().unwrap_or(0),
1009 retention,
1010 })
1011 },
1012 FrameType::Control(ControlCommand::ErrorResponse) => {
1013 let error_msg = frame
1014 .payload
1015 .map(|p| String::from_utf8_lossy(&p).to_string())
1016 .unwrap_or_else(|| "Unknown error".to_string());
1017 Err(ClientError::ServerError(error_msg))
1018 },
1019 other => Err(ClientError::InvalidResponse(format!(
1020 "Expected TopicResponse, got {:?}",
1021 other
1022 ))),
1023 }
1024 }
1025
1026 fn parse_topic_list_response(&self, frame: Frame) -> Result<Vec<TopicInfo>> {
1027 match frame.frame_type {
1028 FrameType::Control(ControlCommand::TopicResponse) => {
1029 let payload = frame.payload.ok_or_else(|| {
1030 ClientError::InvalidResponse("Empty topic list response".to_string())
1031 })?;
1032 let json: serde_json::Value = serde_json::from_slice(&payload)
1033 .map_err(|e| ClientError::ProtocolError(format!("Invalid JSON: {}", e)))?;
1034
1035 let topics = json["topics"]
1036 .as_array()
1037 .map(|arr| {
1038 arr.iter()
1039 .map(|t| {
1040 let retention = if t.get("retention").is_some() {
1041 Some(RetentionInfo {
1042 max_age_secs: t["retention"]["max_age_secs"]
1043 .as_u64()
1044 .unwrap_or(0),
1045 max_bytes: t["retention"]["max_bytes"]
1046 .as_u64()
1047 .unwrap_or(0),
1048 })
1049 } else {
1050 None
1051 };
1052 TopicInfo {
1053 id: t["id"].as_u64().unwrap_or(0) as u32,
1054 name: t["name"].as_str().unwrap_or("").to_string(),
1055 created_at: t["created_at"].as_u64().unwrap_or(0),
1056 retention,
1057 }
1058 })
1059 .collect()
1060 })
1061 .unwrap_or_default();
1062
1063 Ok(topics)
1064 },
1065 FrameType::Control(ControlCommand::ErrorResponse) => {
1066 let error_msg = frame
1067 .payload
1068 .map(|p| String::from_utf8_lossy(&p).to_string())
1069 .unwrap_or_else(|| "Unknown error".to_string());
1070 Err(ClientError::ServerError(error_msg))
1071 },
1072 other => Err(ClientError::InvalidResponse(format!(
1073 "Expected TopicResponse, got {:?}",
1074 other
1075 ))),
1076 }
1077 }
1078
1079 async fn recv_frame(&mut self) -> Result<Frame> {
1080 const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
1082
1083 loop {
1084 if self.read_offset >= LWP_HEADER_SIZE {
1085 let payload_len = u32::from_le_bytes([
1087 self.read_buffer[32],
1088 self.read_buffer[33],
1089 self.read_buffer[34],
1090 self.read_buffer[35],
1091 ]) as usize;
1092 let total_frame_size = LWP_HEADER_SIZE + payload_len;
1093 if total_frame_size > MAX_FRAME_SIZE {
1094 return Err(ClientError::ServerError(format!(
1095 "Frame too large: {} bytes",
1096 total_frame_size
1097 )));
1098 }
1099 if total_frame_size > self.read_buffer.len() {
1100 self.read_buffer.resize(total_frame_size, 0);
1101 }
1102
1103 if let Some((frame, consumed)) = parse_frame(&self.read_buffer[..self.read_offset])?
1104 {
1105 self.read_buffer.copy_within(consumed..self.read_offset, 0);
1106 self.read_offset -= consumed;
1107 if self.read_buffer.len() > 64 * 1024 && self.read_offset < 64 * 1024 {
1109 self.read_buffer.resize(64 * 1024, 0);
1110 }
1111 return Ok(frame);
1112 }
1113 }
1114
1115 let n = tokio::time::timeout(
1116 self.config.read_timeout,
1117 self.stream.read(&mut self.read_buffer[self.read_offset..]),
1118 )
1119 .await
1120 .map_err(|_| ClientError::Timeout)??;
1121
1122 if n == 0 {
1123 return Err(ClientError::ConnectionClosed);
1124 }
1125
1126 self.read_offset += n;
1127 }
1128 }
1129
1130 pub fn config(&self) -> &ClientConfig {
1132 &self.config
1133 }
1134
1135 pub async fn close(mut self) -> Result<()> {
1137 self.stream.shutdown().await?;
1138 Ok(())
1139 }
1140}
1141
1142impl std::fmt::Debug for LanceClient {
1143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1144 f.debug_struct("LanceClient")
1145 .field("addr", &self.config.addr)
1146 .field("batch_id", &self.batch_id.load(Ordering::SeqCst))
1147 .finish()
1148 }
1149}