1use crate::error::{ClientError, Result};
2use crate::tls::TlsClientConfig;
3use bytes::Bytes;
4use lnc_network::{
5 ControlCommand, Frame, FrameType, LWP_HEADER_SIZE, TlsConnector, encode_frame, parse_frame,
6};
7use std::net::SocketAddr;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::task::{Context, Poll};
11use std::time::Duration;
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
13use tokio::net::TcpStream;
14use tokio::net::lookup_host;
15use tracing::{debug, trace, warn};
16
17#[allow(clippy::large_enum_variant)]
19pub enum ClientStream {
20 Tcp(TcpStream),
22 Tls(tokio_rustls::client::TlsStream<TcpStream>),
24}
25
26impl AsyncRead for ClientStream {
27 fn poll_read(
31 self: Pin<&mut Self>,
32 cx: &mut Context<'_>,
33 buf: &mut ReadBuf<'_>,
34 ) -> Poll<std::io::Result<()>> {
35 match self.get_mut() {
36 ClientStream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
37 ClientStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
38 }
39 }
40}
41
42impl AsyncWrite for ClientStream {
43 fn poll_write(
48 self: Pin<&mut Self>,
49 cx: &mut Context<'_>,
50 buf: &[u8],
51 ) -> Poll<std::io::Result<usize>> {
52 match self.get_mut() {
53 ClientStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
54 ClientStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
55 }
56 }
57
58 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
62 match self.get_mut() {
63 ClientStream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
64 ClientStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
65 }
66 }
67
68 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
72 match self.get_mut() {
73 ClientStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
74 ClientStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
75 }
76 }
77}
78
79fn extract_error_message(frame: &Frame) -> String {
81 frame
82 .payload
83 .as_ref()
84 .map(|p| String::from_utf8_lossy(p).to_string())
85 .unwrap_or_else(|| "Unknown error".to_string())
86}
87
88#[allow(dead_code)] fn expect_frame_type(frame: Frame, expected: ControlCommand, expected_name: &str) -> Result<Frame> {
91 match frame.frame_type {
92 FrameType::Control(cmd) if cmd == expected => Ok(frame),
93 FrameType::Control(ControlCommand::ErrorResponse) => {
94 Err(ClientError::ServerError(extract_error_message(&frame)))
95 },
96 other => Err(ClientError::InvalidResponse(format!(
97 "Expected {}, got {:?}",
98 expected_name, other
99 ))),
100 }
101}
102
103fn expect_success_response(frame: Frame) -> Result<()> {
105 match frame.frame_type {
106 FrameType::Control(ControlCommand::TopicResponse) => Ok(()),
107 FrameType::Control(ControlCommand::ErrorResponse) => {
108 Err(ClientError::ServerError(extract_error_message(&frame)))
109 },
110 other => Err(ClientError::InvalidResponse(format!(
111 "Expected TopicResponse, got {:?}",
112 other
113 ))),
114 }
115}
116
/// Client-side authentication settings that can be passed to the ingest calls.
///
/// NOTE(review): the ingest path in this module accepts an `Option<&AuthConfig>` but does
/// not read it; confirm where (or whether) these settings are enforced.
#[derive(Debug, Clone, Default)]
pub struct AuthConfig {
    /// Whether mutual TLS is requested.
    pub mtls_enabled: bool,
    /// Filesystem path of the client certificate, if any.
    pub client_cert_path: Option<String>,
    /// Filesystem path of the client private key, if any.
    pub client_key_path: Option<String>,
}
127
/// Retention policy attached to a topic, as reported by the server.
#[derive(Debug, Clone, Default)]
pub struct RetentionInfo {
    /// Maximum age of retained data, in seconds.
    pub max_age_secs: u64,
    /// Maximum retained size, in bytes.
    pub max_bytes: u64,
}
136
137#[derive(Debug, Clone)]
139pub struct TopicInfo {
140 pub id: u32,
142 pub name: String,
144 pub created_at: u64,
146 pub topic_epoch: u64,
148 pub retention: Option<RetentionInfo>,
150}
151
152#[derive(Debug, Clone)]
154pub struct FetchResult {
155 pub data: Bytes,
157 pub next_offset: u64,
159 pub bytes_returned: u32,
161 pub record_count: u32,
163}
164
/// Server acknowledgement of a subscribe request.
#[derive(Debug, Clone)]
pub struct SubscribeResult {
    /// Consumer id confirmed by the server.
    pub consumer_id: u64,
    /// Offset from which the subscription starts.
    pub start_offset: u64,
}
173
/// Server acknowledgement of an offset commit.
#[derive(Debug, Clone)]
pub struct CommitResult {
    /// Consumer id the commit applied to.
    pub consumer_id: u64,
    /// Offset the server recorded as committed.
    pub committed_offset: u64,
}
182
/// Snapshot of cluster state as reported by the node the client is connected to.
#[derive(Debug, Clone)]
pub struct ClusterStatus {
    /// Id of the responding node.
    pub node_id: u16,
    /// Whether the responding node considers itself leader.
    pub is_leader: bool,
    /// Current leader id, if known.
    pub leader_id: Option<u16>,
    /// Current term number.
    pub current_term: u64,
    /// Total number of nodes known to the responder.
    pub node_count: usize,
    /// Number of nodes the responder considers healthy.
    pub healthy_nodes: usize,
    /// Whether a quorum is available.
    pub quorum_available: bool,
    /// Per-peer state strings keyed by node id.
    pub peer_states: std::collections::HashMap<u16, String>,
}
195
196#[derive(Debug, Clone)]
198pub struct ClientConfig {
199 pub addr: String,
201 pub connect_timeout: Duration,
203 pub read_timeout: Duration,
205 pub write_timeout: Duration,
207 pub keepalive_interval: Duration,
209 pub tls: Option<TlsClientConfig>,
211}
212
213impl Default for ClientConfig {
214 fn default() -> Self {
215 Self {
216 addr: "127.0.0.1:1992".to_string(),
217 connect_timeout: Duration::from_secs(10),
218 read_timeout: Duration::from_secs(30),
219 write_timeout: Duration::from_secs(10),
220 keepalive_interval: Duration::from_secs(10),
221 tls: None,
222 }
223 }
224}
225
226impl ClientConfig {
227 pub fn new(addr: impl Into<String>) -> Self {
233 Self {
234 addr: addr.into(),
235 ..Default::default()
236 }
237 }
238
239 pub fn with_tls(mut self, tls_config: TlsClientConfig) -> Self {
250 self.tls = Some(tls_config);
251 self
252 }
253
254 pub fn is_tls_enabled(&self) -> bool {
256 self.tls.is_some()
257 }
258}
259
260pub struct LanceClient {
264 stream: ClientStream,
265 config: ClientConfig,
266 batch_id: AtomicU64,
267 read_buffer: Vec<u8>,
268 read_offset: usize,
269}
270
271impl LanceClient {
272 async fn resolve_address(addr: &str) -> Result<SocketAddr> {
287 if let Ok(socket_addr) = addr.parse::<SocketAddr>() {
289 return Ok(socket_addr);
290 }
291
292 let mut addrs = lookup_host(addr).await.map_err(|e| {
294 ClientError::ProtocolError(format!("DNS resolution failed for '{}': {}", addr, e))
295 })?;
296
297 addrs
298 .next()
299 .ok_or_else(|| ClientError::ProtocolError(format!("No addresses found for '{}'", addr)))
300 }
301
302 pub async fn connect(config: ClientConfig) -> Result<Self> {
307 if let Some(ref tls_config) = config.tls {
309 return Self::connect_tls(config.clone(), tls_config.clone()).await;
310 }
311
312 debug!(addr = %config.addr, "Connecting to LANCE server");
313
314 let socket_addr = Self::resolve_address(&config.addr).await?;
316 debug!(resolved_addr = %socket_addr, "Resolved server address");
317
318 let stream = tokio::time::timeout(config.connect_timeout, TcpStream::connect(socket_addr))
319 .await
320 .map_err(|_| ClientError::Timeout)?
321 .map_err(ClientError::ConnectionFailed)?;
322
323 stream.set_nodelay(true)?;
324
325 debug!(addr = %config.addr, "Connected to LANCE server");
326
327 Ok(Self {
328 stream: ClientStream::Tcp(stream),
329 config,
330 batch_id: AtomicU64::new(0),
331 read_buffer: vec![0u8; 64 * 1024],
332 read_offset: 0,
333 })
334 }
335
336 pub async fn connect_tls(config: ClientConfig, tls_config: TlsClientConfig) -> Result<Self> {
353 debug!(addr = %config.addr, "Connecting to LANCE server with TLS");
354
355 let socket_addr = Self::resolve_address(&config.addr).await?;
357 debug!(resolved_addr = %socket_addr, "Resolved server address");
358
359 let tcp_stream =
361 tokio::time::timeout(config.connect_timeout, TcpStream::connect(socket_addr))
362 .await
363 .map_err(|_| ClientError::Timeout)?
364 .map_err(ClientError::ConnectionFailed)?;
365
366 tcp_stream.set_nodelay(true)?;
367
368 let network_config = tls_config.to_network_config();
370 let connector =
371 TlsConnector::new(network_config).map_err(|e| ClientError::TlsError(e.to_string()))?;
372
373 let server_name = tls_config.server_name.unwrap_or_else(|| {
375 config
377 .addr
378 .rsplit_once(':')
379 .map(|(host, _)| host.to_string())
380 .unwrap_or_else(|| socket_addr.ip().to_string())
381 });
382
383 let tls_stream = connector
385 .connect(&server_name, tcp_stream)
386 .await
387 .map_err(|e| ClientError::TlsError(e.to_string()))?;
388
389 debug!(addr = %config.addr, "TLS connection established");
390
391 Ok(Self {
392 stream: ClientStream::Tls(tls_stream),
393 config,
394 batch_id: AtomicU64::new(0),
395 read_buffer: vec![0u8; 64 * 1024],
396 read_offset: 0,
397 })
398 }
399
400 pub async fn connect_to(addr: &str) -> Result<Self> {
405 Self::connect(ClientConfig::new(addr)).await
406 }
407
408 pub async fn connect_tls_to(addr: &str, tls_config: TlsClientConfig) -> Result<Self> {
413 Self::connect_tls(ClientConfig::new(addr), tls_config).await
414 }
415
416 fn next_batch_id(&self) -> u64 {
417 self.batch_id.fetch_add(1, Ordering::SeqCst) + 1
418 }
419
420 pub async fn send_ingest(&mut self, payload: Bytes, record_count: u32) -> Result<u64> {
434 self.send_ingest_to_topic(0, payload, record_count, None)
435 .await
436 }
437
438 pub async fn send_ingest_to_topic(
451 &mut self,
452 topic_id: u32,
453 payload: Bytes,
454 record_count: u32,
455 _auth_config: Option<&AuthConfig>,
456 ) -> Result<u64> {
457 let batch_id = self.next_batch_id();
458 let timestamp_ns = std::time::SystemTime::now()
459 .duration_since(std::time::UNIX_EPOCH)
460 .map(|d| d.as_nanos() as u64)
461 .unwrap_or(0);
462
463 let frame =
464 Frame::new_ingest_with_topic(batch_id, timestamp_ns, record_count, payload, topic_id);
465 let frame_bytes = encode_frame(&frame);
466
467 trace!(
468 batch_id,
469 topic_id,
470 payload_len = frame.payload_length(),
471 "Sending ingest frame"
472 );
473
474 tokio::time::timeout(
475 self.config.write_timeout,
476 self.stream.write_all(&frame_bytes),
477 )
478 .await
479 .map_err(|_| ClientError::Timeout)??;
480
481 Ok(batch_id)
482 }
483
484 pub async fn send_ingest_sync(&mut self, payload: Bytes, record_count: u32) -> Result<u64> {
495 self.send_ingest_to_topic_sync(0, payload, record_count, None)
496 .await
497 }
498
499 pub async fn send_ingest_to_topic_sync(
512 &mut self,
513 topic_id: u32,
514 payload: Bytes,
515 record_count: u32,
516 auth_config: Option<&AuthConfig>,
517 ) -> Result<u64> {
518 let batch_id = self
519 .send_ingest_to_topic(topic_id, payload, record_count, auth_config)
520 .await?;
521 self.wait_for_ack(batch_id).await
522 }
523
524 async fn wait_for_ack(&mut self, expected_batch_id: u64) -> Result<u64> {
537 let frame = self.recv_frame().await?;
538
539 match frame.frame_type {
540 FrameType::Ack => {
541 let acked_id = frame.batch_id();
542 if acked_id != expected_batch_id {
543 return Err(ClientError::InvalidResponse(format!(
544 "Ack batch_id mismatch: sent {}, received {}",
545 expected_batch_id, acked_id
546 )));
547 }
548 trace!(batch_id = acked_id, "Received ack");
549 Ok(acked_id)
550 },
551 FrameType::Control(ControlCommand::ErrorResponse) => {
552 let error_msg = frame
553 .payload
554 .map(|p| String::from_utf8_lossy(&p).to_string())
555 .unwrap_or_else(|| "Unknown error".to_string());
556 Err(ClientError::ServerError(error_msg))
557 },
558 FrameType::Backpressure => {
559 warn!("Server signaled backpressure");
560 Err(ClientError::ServerBackpressure)
561 },
562 other => Err(ClientError::InvalidResponse(format!(
563 "Expected Ack, got {:?}",
564 other
565 ))),
566 }
567 }
568
569 pub async fn recv_ack(&mut self) -> Result<u64> {
580 let frame = self.recv_frame().await?;
581
582 match frame.frame_type {
583 FrameType::Ack => {
584 trace!(batch_id = frame.batch_id(), "Received ack");
585 Ok(frame.batch_id())
586 },
587 FrameType::Backpressure => {
588 warn!("Server signaled backpressure");
589 Err(ClientError::ServerBackpressure)
590 },
591 other => Err(ClientError::InvalidResponse(format!(
592 "Expected Ack, got {:?}",
593 other
594 ))),
595 }
596 }
597
598 pub async fn send_keepalive(&mut self) -> Result<()> {
605 let frame = Frame::new_keepalive();
606 let frame_bytes = encode_frame(&frame);
607
608 trace!("Sending keepalive");
609
610 tokio::time::timeout(
611 self.config.write_timeout,
612 self.stream.write_all(&frame_bytes),
613 )
614 .await
615 .map_err(|_| ClientError::Timeout)??;
616
617 Ok(())
618 }
619
620 pub async fn recv_keepalive(&mut self) -> Result<()> {
630 let frame = self.recv_frame().await?;
631
632 match frame.frame_type {
633 FrameType::Keepalive => {
634 trace!("Received keepalive response");
635 Ok(())
636 },
637 other => Err(ClientError::InvalidResponse(format!(
638 "Expected Keepalive, got {:?}",
639 other
640 ))),
641 }
642 }
643
644 pub async fn ping(&mut self) -> Result<Duration> {
646 let start = std::time::Instant::now();
647 self.send_keepalive().await?;
648 self.recv_keepalive().await?;
649 Ok(start.elapsed())
650 }
651
652 pub async fn create_topic(&mut self, name: &str) -> Result<TopicInfo> {
654 let frame = Frame::new_create_topic(name);
655 let frame_bytes = encode_frame(&frame);
656
657 trace!(topic_name = %name, "Creating topic");
658
659 tokio::time::timeout(
660 self.config.write_timeout,
661 self.stream.write_all(&frame_bytes),
662 )
663 .await
664 .map_err(|_| ClientError::Timeout)??;
665
666 let response = self.recv_frame().await?;
667 self.parse_topic_response(response)
668 }
669
670 pub async fn ensure_topic(
675 &mut self,
676 name: &str,
677 max_attempts: usize,
678 base_backoff_ms: u64,
679 ) -> Result<TopicInfo> {
680 let attempts = max_attempts.max(1);
681 let mut last_error: Option<ClientError> = None;
682 let mut saw_retryable_error = false;
683
684 for attempt in 1..=attempts {
685 let mut retryable_this_attempt = false;
686
687 match self.create_topic(name).await {
688 Ok(info) => {
689 trace!(
690 topic_id = info.id,
691 topic_name = %info.name,
692 attempt,
693 max_attempts = attempts,
694 "Topic ensured via create_topic"
695 );
696 return Ok(info);
697 },
698 Err(create_err) => {
699 if create_err.is_retryable() {
700 retryable_this_attempt = true;
701 saw_retryable_error = true;
702 }
703 last_error = Some(ClientError::ServerError(create_err.to_string()));
704 warn!(
705 topic_name = %name,
706 attempt,
707 max_attempts = attempts,
708 error = %create_err,
709 "create_topic failed during ensure_topic; retrying with list fallback"
710 );
711 },
712 }
713
714 match self.list_topics().await {
715 Ok(topics) => {
716 if let Some(topic) = topics.into_iter().find(|t| t.name == name) {
717 trace!(
718 topic_id = topic.id,
719 topic_name = %topic.name,
720 attempt,
721 max_attempts = attempts,
722 "Topic ensured via list_topics fallback"
723 );
724 return Ok(topic);
725 }
726 },
727 Err(list_err) => {
728 if list_err.is_retryable() {
729 retryable_this_attempt = true;
730 saw_retryable_error = true;
731 }
732 last_error = Some(ClientError::ServerError(list_err.to_string()));
733 warn!(
734 topic_name = %name,
735 attempt,
736 max_attempts = attempts,
737 error = %list_err,
738 "list_topics failed during ensure_topic"
739 );
740 },
741 }
742
743 if attempt < attempts {
744 let backoff_ms = if retryable_this_attempt {
745 base_backoff_ms.saturating_mul(attempt as u64).max(1)
746 } else {
747 base_backoff_ms.max(1)
749 };
750 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
751
752 let reconnect_config = self.config.clone();
755 match Self::connect(reconnect_config).await {
756 Ok(new_client) => {
757 *self = new_client;
758 },
759 Err(reconnect_err) => {
760 warn!(
761 topic_name = %name,
762 attempt,
763 max_attempts = attempts,
764 error = %reconnect_err,
765 "ensure_topic reconnect attempt failed"
766 );
767 last_error = Some(reconnect_err);
768 },
769 }
770 }
771 }
772
773 if let Some(err) = last_error {
774 return Err(ClientError::ServerError(format!(
775 "ensure_topic('{}') failed after {} attempts: {}",
776 name, attempts, err
777 )));
778 }
779
780 if saw_retryable_error {
781 return Err(ClientError::ServerError(format!(
782 "ensure_topic('{}') exhausted {} retryable attempts",
783 name, attempts
784 )));
785 }
786
787 Err(ClientError::ServerError(format!(
788 "topic '{}' not found after {} ensure_topic attempts",
789 name, attempts
790 )))
791 }
792
793 pub async fn ensure_topic_default(&mut self, name: &str) -> Result<TopicInfo> {
795 const DEFAULT_ENSURE_TOPIC_ATTEMPTS: usize = 20;
796 const DEFAULT_ENSURE_TOPIC_BACKOFF_MS: u64 = 500;
797 self.ensure_topic(
798 name,
799 DEFAULT_ENSURE_TOPIC_ATTEMPTS,
800 DEFAULT_ENSURE_TOPIC_BACKOFF_MS,
801 )
802 .await
803 }
804
805 pub async fn list_topics(&mut self) -> Result<Vec<TopicInfo>> {
807 let frame = Frame::new_list_topics();
808 let frame_bytes = encode_frame(&frame);
809
810 trace!("Listing topics");
811
812 tokio::time::timeout(
813 self.config.write_timeout,
814 self.stream.write_all(&frame_bytes),
815 )
816 .await
817 .map_err(|_| ClientError::Timeout)??;
818
819 let response = self.recv_frame().await?;
820 self.parse_topic_list_response(response)
821 }
822
823 pub async fn get_topic(&mut self, topic_id: u32) -> Result<TopicInfo> {
825 let frame = Frame::new_get_topic(topic_id);
826 let frame_bytes = encode_frame(&frame);
827
828 trace!(topic_id, "Getting topic");
829
830 tokio::time::timeout(
831 self.config.write_timeout,
832 self.stream.write_all(&frame_bytes),
833 )
834 .await
835 .map_err(|_| ClientError::Timeout)??;
836
837 let response = self.recv_frame().await?;
838 self.parse_topic_response(response)
839 }
840
841 pub async fn delete_topic(&mut self, topic_id: u32) -> Result<()> {
843 let frame = Frame::new_delete_topic(topic_id);
844 let frame_bytes = encode_frame(&frame);
845
846 trace!(topic_id, "Deleting topic");
847
848 tokio::time::timeout(
849 self.config.write_timeout,
850 self.stream.write_all(&frame_bytes),
851 )
852 .await
853 .map_err(|_| ClientError::Timeout)??;
854
855 let response = self.recv_frame().await?;
856 self.parse_delete_response(response)
857 }
858
859 pub async fn set_retention(
867 &mut self,
868 topic_id: u32,
869 max_age_secs: u64,
870 max_bytes: u64,
871 ) -> Result<()> {
872 let frame = Frame::new_set_retention(topic_id, max_age_secs, max_bytes);
873 let frame_bytes = encode_frame(&frame);
874
875 trace!(
876 topic_id,
877 max_age_secs, max_bytes, "Setting retention policy"
878 );
879
880 tokio::time::timeout(
881 self.config.write_timeout,
882 self.stream.write_all(&frame_bytes),
883 )
884 .await
885 .map_err(|_| ClientError::Timeout)??;
886
887 let response = self.recv_frame().await?;
888 self.parse_retention_response(response)
889 }
890
891 pub async fn create_topic_with_retention(
898 &mut self,
899 name: &str,
900 max_age_secs: u64,
901 max_bytes: u64,
902 ) -> Result<TopicInfo> {
903 let frame = Frame::new_create_topic_with_retention(name, max_age_secs, max_bytes);
904 let frame_bytes = encode_frame(&frame);
905
906 trace!(
907 name,
908 max_age_secs, max_bytes, "Creating topic with retention"
909 );
910
911 tokio::time::timeout(
912 self.config.write_timeout,
913 self.stream.write_all(&frame_bytes),
914 )
915 .await
916 .map_err(|_| ClientError::Timeout)??;
917
918 let response = self.recv_frame().await?;
919 self.parse_topic_response(response)
920 }
921
922 pub async fn get_cluster_status(&mut self) -> Result<ClusterStatus> {
924 let frame = Frame::new_get_cluster_status();
925 let frame_bytes = encode_frame(&frame);
926
927 tokio::time::timeout(
928 self.config.write_timeout,
929 self.stream.write_all(&frame_bytes),
930 )
931 .await
932 .map_err(|_| ClientError::Timeout)??;
933
934 let response = self.recv_frame().await?;
935 self.parse_cluster_status_response(response)
936 }
937
938 fn parse_cluster_status_response(&self, frame: Frame) -> Result<ClusterStatus> {
939 match frame.frame_type {
940 FrameType::Control(ControlCommand::ClusterStatusResponse) => {
941 let payload = frame.payload.ok_or_else(|| {
942 ClientError::InvalidResponse("Empty cluster status response".to_string())
943 })?;
944 let json: serde_json::Value = serde_json::from_slice(&payload)
945 .map_err(|e| ClientError::ProtocolError(format!("Invalid JSON: {}", e)))?;
946
947 let peer_states: std::collections::HashMap<u16, String> = json["peer_states"]
948 .as_object()
949 .map(|obj| {
950 obj.iter()
951 .filter_map(|(k, v)| {
952 k.parse::<u16>()
953 .ok()
954 .map(|id| (id, v.as_str().unwrap_or("unknown").to_string()))
955 })
956 .collect()
957 })
958 .unwrap_or_default();
959
960 Ok(ClusterStatus {
961 node_id: json["node_id"].as_u64().unwrap_or(0) as u16,
962 is_leader: json["is_leader"].as_bool().unwrap_or(false),
963 leader_id: json["leader_id"].as_u64().map(|id| id as u16),
964 current_term: json["current_term"].as_u64().unwrap_or(0),
965 node_count: json["node_count"].as_u64().unwrap_or(1) as usize,
966 healthy_nodes: json["healthy_nodes"].as_u64().unwrap_or(1) as usize,
967 quorum_available: json["quorum_available"].as_bool().unwrap_or(true),
968 peer_states,
969 })
970 },
971 FrameType::Control(ControlCommand::ErrorResponse) => {
972 let error_msg = frame
973 .payload
974 .map(|p| String::from_utf8_lossy(&p).to_string())
975 .unwrap_or_else(|| "Unknown error".to_string());
976 Err(ClientError::ServerError(error_msg))
977 },
978 other => Err(ClientError::InvalidResponse(format!(
979 "Expected ClusterStatusResponse, got {:?}",
980 other
981 ))),
982 }
983 }
984
985 pub async fn fetch(
988 &mut self,
989 topic_id: u32,
990 start_offset: u64,
991 max_bytes: u32,
992 ) -> Result<FetchResult> {
993 let frame = Frame::new_fetch(topic_id, start_offset, max_bytes);
994 let frame_bytes = encode_frame(&frame);
995
996 trace!(topic_id, start_offset, max_bytes, "Fetching data");
997
998 tokio::time::timeout(
999 self.config.write_timeout,
1000 self.stream.write_all(&frame_bytes),
1001 )
1002 .await
1003 .map_err(|_| ClientError::Timeout)??;
1004
1005 let response = self.recv_frame().await?;
1006 self.parse_fetch_response(response)
1007 }
1008
1009 pub async fn subscribe(
1026 &mut self,
1027 topic_id: u32,
1028 start_offset: u64,
1029 max_batch_bytes: u32,
1030 consumer_id: u64,
1031 ) -> Result<SubscribeResult> {
1032 let frame = Frame::new_subscribe(topic_id, start_offset, max_batch_bytes, consumer_id);
1033 let frame_bytes = encode_frame(&frame);
1034
1035 trace!(topic_id, start_offset, consumer_id, "Subscribing to topic");
1036
1037 tokio::time::timeout(
1038 self.config.write_timeout,
1039 self.stream.write_all(&frame_bytes),
1040 )
1041 .await
1042 .map_err(|_| ClientError::Timeout)??;
1043
1044 let response = self.recv_frame().await?;
1045 self.parse_subscribe_response(response)
1046 }
1047
1048 pub async fn unsubscribe(&mut self, topic_id: u32, consumer_id: u64) -> Result<()> {
1062 let frame = Frame::new_unsubscribe(topic_id, consumer_id);
1063 let frame_bytes = encode_frame(&frame);
1064
1065 trace!(topic_id, consumer_id, "Unsubscribing from topic");
1066
1067 tokio::time::timeout(
1068 self.config.write_timeout,
1069 self.stream.write_all(&frame_bytes),
1070 )
1071 .await
1072 .map_err(|_| ClientError::Timeout)??;
1073
1074 let response = self.recv_frame().await?;
1076 match response.frame_type {
1077 FrameType::Ack => Ok(()),
1078 FrameType::Control(ControlCommand::ErrorResponse) => {
1079 let error_msg = response
1080 .payload
1081 .map(|p| String::from_utf8_lossy(&p).to_string())
1082 .unwrap_or_else(|| "Unknown error".to_string());
1083 Err(ClientError::ServerError(error_msg))
1084 },
1085 other => Err(ClientError::InvalidResponse(format!(
1086 "Expected Ack, got {:?}",
1087 other
1088 ))),
1089 }
1090 }
1091
1092 pub async fn commit_offset(
1094 &mut self,
1095 topic_id: u32,
1096 consumer_id: u64,
1097 offset: u64,
1098 ) -> Result<CommitResult> {
1099 let frame = Frame::new_commit_offset(topic_id, consumer_id, offset);
1100 let frame_bytes = encode_frame(&frame);
1101
1102 trace!(topic_id, consumer_id, offset, "Committing offset");
1103
1104 tokio::time::timeout(
1105 self.config.write_timeout,
1106 self.stream.write_all(&frame_bytes),
1107 )
1108 .await
1109 .map_err(|_| ClientError::Timeout)??;
1110
1111 let response = self.recv_frame().await?;
1112 self.parse_commit_response(response)
1113 }
1114
1115 fn parse_subscribe_response(&self, frame: Frame) -> Result<SubscribeResult> {
1116 match frame.frame_type {
1117 FrameType::Control(ControlCommand::SubscribeAck) => {
1118 let payload = frame.payload.ok_or_else(|| {
1119 ClientError::InvalidResponse("Empty subscribe response".to_string())
1120 })?;
1121
1122 if payload.len() < 16 {
1123 return Err(ClientError::ProtocolError(
1124 "Subscribe response too small".to_string(),
1125 ));
1126 }
1127
1128 let consumer_id = u64::from_le_bytes([
1129 payload[0], payload[1], payload[2], payload[3], payload[4], payload[5],
1130 payload[6], payload[7],
1131 ]);
1132 let start_offset = u64::from_le_bytes([
1133 payload[8],
1134 payload[9],
1135 payload[10],
1136 payload[11],
1137 payload[12],
1138 payload[13],
1139 payload[14],
1140 payload[15],
1141 ]);
1142
1143 Ok(SubscribeResult {
1144 consumer_id,
1145 start_offset,
1146 })
1147 },
1148 FrameType::Control(ControlCommand::ErrorResponse) => {
1149 let error_msg = frame
1150 .payload
1151 .map(|p| String::from_utf8_lossy(&p).to_string())
1152 .unwrap_or_else(|| "Unknown error".to_string());
1153 Err(ClientError::ServerError(error_msg))
1154 },
1155 other => Err(ClientError::InvalidResponse(format!(
1156 "Expected SubscribeAck, got {:?}",
1157 other
1158 ))),
1159 }
1160 }
1161
1162 fn parse_commit_response(&self, frame: Frame) -> Result<CommitResult> {
1163 match frame.frame_type {
1164 FrameType::Control(ControlCommand::CommitAck) => {
1165 let payload = frame.payload.ok_or_else(|| {
1166 ClientError::InvalidResponse("Empty commit response".to_string())
1167 })?;
1168
1169 if payload.len() < 16 {
1170 return Err(ClientError::ProtocolError(
1171 "Commit response too small".to_string(),
1172 ));
1173 }
1174
1175 let consumer_id = u64::from_le_bytes([
1176 payload[0], payload[1], payload[2], payload[3], payload[4], payload[5],
1177 payload[6], payload[7],
1178 ]);
1179 let committed_offset = u64::from_le_bytes([
1180 payload[8],
1181 payload[9],
1182 payload[10],
1183 payload[11],
1184 payload[12],
1185 payload[13],
1186 payload[14],
1187 payload[15],
1188 ]);
1189
1190 Ok(CommitResult {
1191 consumer_id,
1192 committed_offset,
1193 })
1194 },
1195 FrameType::Control(ControlCommand::ErrorResponse) => {
1196 let error_msg = frame
1197 .payload
1198 .map(|p| String::from_utf8_lossy(&p).to_string())
1199 .unwrap_or_else(|| "Unknown error".to_string());
1200 Err(ClientError::ServerError(error_msg))
1201 },
1202 other => Err(ClientError::InvalidResponse(format!(
1203 "Expected CommitAck, got {:?}",
1204 other
1205 ))),
1206 }
1207 }
1208
1209 fn parse_fetch_response(&self, frame: Frame) -> Result<FetchResult> {
1210 match frame.frame_type {
1211 FrameType::Control(ControlCommand::CatchingUp) => {
1212 let server_offset = frame
1213 .payload
1214 .as_ref()
1215 .filter(|p| p.len() >= 8)
1216 .map(|p| u64::from_le_bytes([p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7]]))
1217 .unwrap_or(0);
1218 Err(ClientError::ServerCatchingUp { server_offset })
1219 },
1220 FrameType::Control(ControlCommand::FetchResponse) => {
1221 let payload = frame.payload.ok_or_else(|| {
1222 ClientError::InvalidResponse("Empty fetch response".to_string())
1223 })?;
1224
1225 if payload.len() < 16 {
1226 return Err(ClientError::ProtocolError(
1227 "Fetch response too small".to_string(),
1228 ));
1229 }
1230
1231 let next_offset = u64::from_le_bytes([
1232 payload[0], payload[1], payload[2], payload[3], payload[4], payload[5],
1233 payload[6], payload[7],
1234 ]);
1235 let bytes_returned =
1236 u32::from_le_bytes([payload[8], payload[9], payload[10], payload[11]]);
1237 let record_count =
1238 u32::from_le_bytes([payload[12], payload[13], payload[14], payload[15]]);
1239 let data = payload.slice(16..);
1240
1241 Ok(FetchResult {
1242 data,
1243 next_offset,
1244 bytes_returned,
1245 record_count,
1246 })
1247 },
1248 FrameType::Control(ControlCommand::ErrorResponse) => {
1249 let error_msg = frame
1250 .payload
1251 .map(|p| String::from_utf8_lossy(&p).to_string())
1252 .unwrap_or_else(|| "Unknown error".to_string());
1253 Err(ClientError::ServerError(error_msg))
1254 },
1255 other => Err(ClientError::InvalidResponse(format!(
1256 "Expected FetchResponse, got {:?}",
1257 other
1258 ))),
1259 }
1260 }
1261
1262 fn parse_delete_response(&self, frame: Frame) -> Result<()> {
1263 expect_success_response(frame)
1264 }
1265
1266 fn parse_retention_response(&self, frame: Frame) -> Result<()> {
1267 expect_success_response(frame)
1268 }
1269
1270 fn parse_topic_response(&self, frame: Frame) -> Result<TopicInfo> {
1271 match frame.frame_type {
1272 FrameType::Control(ControlCommand::TopicResponse) => {
1273 let payload = frame.payload.ok_or_else(|| {
1274 ClientError::InvalidResponse("Empty topic response".to_string())
1275 })?;
1276 let json: serde_json::Value = serde_json::from_slice(&payload)
1277 .map_err(|e| ClientError::ProtocolError(format!("Invalid JSON: {}", e)))?;
1278
1279 let retention = if json.get("retention").is_some() {
1280 Some(RetentionInfo {
1281 max_age_secs: json["retention"]["max_age_secs"].as_u64().unwrap_or(0),
1282 max_bytes: json["retention"]["max_bytes"].as_u64().unwrap_or(0),
1283 })
1284 } else {
1285 None
1286 };
1287
1288 Ok(TopicInfo {
1289 id: json["id"].as_u64().unwrap_or(0) as u32,
1290 name: json["name"].as_str().unwrap_or("").to_string(),
1291 created_at: json["created_at"].as_u64().unwrap_or(0),
1292 topic_epoch: json["topic_epoch"].as_u64().unwrap_or(1),
1293 retention,
1294 })
1295 },
1296 FrameType::Control(ControlCommand::ErrorResponse) => {
1297 let error_msg = frame
1298 .payload
1299 .map(|p| String::from_utf8_lossy(&p).to_string())
1300 .unwrap_or_else(|| "Unknown error".to_string());
1301 Err(ClientError::ServerError(error_msg))
1302 },
1303 other => Err(ClientError::InvalidResponse(format!(
1304 "Expected TopicResponse, got {:?}",
1305 other
1306 ))),
1307 }
1308 }
1309
1310 fn parse_topic_list_response(&self, frame: Frame) -> Result<Vec<TopicInfo>> {
1311 match frame.frame_type {
1312 FrameType::Control(ControlCommand::TopicResponse) => {
1313 let payload = frame.payload.ok_or_else(|| {
1314 ClientError::InvalidResponse("Empty topic list response".to_string())
1315 })?;
1316 let json: serde_json::Value = serde_json::from_slice(&payload)
1317 .map_err(|e| ClientError::ProtocolError(format!("Invalid JSON: {}", e)))?;
1318
1319 let topics = json["topics"]
1320 .as_array()
1321 .map(|arr| {
1322 arr.iter()
1323 .map(|t| {
1324 let retention = if t.get("retention").is_some() {
1325 Some(RetentionInfo {
1326 max_age_secs: t["retention"]["max_age_secs"]
1327 .as_u64()
1328 .unwrap_or(0),
1329 max_bytes: t["retention"]["max_bytes"]
1330 .as_u64()
1331 .unwrap_or(0),
1332 })
1333 } else {
1334 None
1335 };
1336 TopicInfo {
1337 id: t["id"].as_u64().unwrap_or(0) as u32,
1338 name: t["name"].as_str().unwrap_or("").to_string(),
1339 created_at: t["created_at"].as_u64().unwrap_or(0),
1340 topic_epoch: t["topic_epoch"].as_u64().unwrap_or(1),
1341 retention,
1342 }
1343 })
1344 .collect()
1345 })
1346 .unwrap_or_default();
1347
1348 Ok(topics)
1349 },
1350 FrameType::Control(ControlCommand::ErrorResponse) => {
1351 let error_msg = frame
1352 .payload
1353 .map(|p| String::from_utf8_lossy(&p).to_string())
1354 .unwrap_or_else(|| "Unknown error".to_string());
1355 Err(ClientError::ServerError(error_msg))
1356 },
1357 other => Err(ClientError::InvalidResponse(format!(
1358 "Expected TopicResponse, got {:?}",
1359 other
1360 ))),
1361 }
1362 }
1363
    /// Read the next complete frame from the connection.
    ///
    /// Received bytes accumulate in `self.read_buffer[..self.read_offset]`. Each loop
    /// iteration does the following:
    /// 1. Once a full header is buffered, read the payload length from header bytes
    ///    32..36 (little-endian `u32`).
    /// 2. Reject frames larger than `MAX_FRAME_SIZE`.
    /// 3. Grow the buffer if the whole frame will not fit.
    /// 4. Try `parse_frame`. On success, move leftover bytes (the start of the next frame)
    ///    to the front. A buffer grown beyond 64 KiB is shrunk back when the leftover
    ///    allows it.
    /// 5. Otherwise read more bytes from the stream and repeat.
    ///
    /// # Errors
    ///
    /// - [`ClientError::ServerError`] if the announced frame exceeds 16 MiB.
    /// - [`ClientError::Timeout`] if a single read takes longer than `config.read_timeout`
    ///   (the timeout applies per read, not to the whole frame).
    /// - [`ClientError::ConnectionClosed`] if a read returns 0 bytes.
    /// - Errors from `parse_frame` or the underlying read, converted with `?`.
    async fn recv_frame(&mut self) -> Result<Frame> {
        // Upper bound on header + payload; checked before the buffer is grown.
        const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;

        loop {
            // Only look at the header once all of it is buffered.
            if self.read_offset >= LWP_HEADER_SIZE {
                // NOTE(review): assumes the payload length is the little-endian u32 at header
                // bytes 32..36 (so LWP_HEADER_SIZE >= 36) — confirm against lnc_network's
                // header layout.
                let payload_len = u32::from_le_bytes([
                    self.read_buffer[32],
                    self.read_buffer[33],
                    self.read_buffer[34],
                    self.read_buffer[35],
                ]) as usize;
                let total_frame_size = LWP_HEADER_SIZE + payload_len;
                if total_frame_size > MAX_FRAME_SIZE {
                    return Err(ClientError::ServerError(format!(
                        "Frame too large: {} bytes",
                        total_frame_size
                    )));
                }
                // Make room for the entire frame so the read below always has space.
                if total_frame_size > self.read_buffer.len() {
                    self.read_buffer.resize(total_frame_size, 0);
                }

                if let Some((frame, consumed)) = parse_frame(&self.read_buffer[..self.read_offset])?
                {
                    // Keep any bytes that belong to the next frame at the front.
                    self.read_buffer.copy_within(consumed..self.read_offset, 0);
                    self.read_offset -= consumed;
                    // Give back memory from an oversized frame, but only when the leftover
                    // bytes still fit in the default 64 KiB buffer.
                    if self.read_buffer.len() > 64 * 1024 && self.read_offset < 64 * 1024 {
                        self.read_buffer.resize(64 * 1024, 0);
                    }
                    return Ok(frame);
                }
            }

            // Not enough data yet: append more after the bytes already buffered.
            let n = tokio::time::timeout(
                self.config.read_timeout,
                self.stream.read(&mut self.read_buffer[self.read_offset..]),
            )
            .await
            .map_err(|_| ClientError::Timeout)??;

            // Zero bytes read is treated as the peer closing the connection.
            if n == 0 {
                return Err(ClientError::ConnectionClosed);
            }

            self.read_offset += n;
        }
    }
1425
1426 pub fn config(&self) -> &ClientConfig {
1428 &self.config
1429 }
1430
1431 pub async fn close(mut self) -> Result<()> {
1433 self.stream.shutdown().await?;
1434 Ok(())
1435 }
1436}
1437
1438impl std::fmt::Debug for LanceClient {
1439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1440 f.debug_struct("LanceClient")
1441 .field("addr", &self.config.addr)
1442 .field("batch_id", &self.batch_id.load(Ordering::SeqCst))
1443 .finish()
1444 }
1445}