zus_common/endpoint.rs

1use {
2  bytes::Bytes,
3  futures::{StreamExt, stream::SplitSink},
4  std::{
5    sync::{
6      Arc,
7      atomic::{AtomicU64, Ordering},
8    },
9    time::Duration,
10  },
11  tokio::{
12    net::TcpStream,
13    sync::{mpsc, oneshot},
14    time::timeout,
15  },
16  tokio_util::codec::Framed,
17  tracing::{info, instrument},
18};
19
20use crate::{
21  codec::RpcCodec,
22  error::{Result, ZusError},
23  protocol::RpcMessage,
24};
25
26/// Request sent to the writer task (includes response callback for sync calls)
27struct WriterRequest {
28  msg: RpcMessage,
29  response_tx: Option<oneshot::Sender<Bytes>>,
30}
31
32/// Response routing info stored by sequence number
33struct PendingRequest {
34  response_tx: oneshot::Sender<Bytes>,
35}
36
37/// Writer task that owns the SplitSink and receives messages via channel
38async fn writer_task(
39  mut writer: SplitSink<Framed<TcpStream, RpcCodec>, RpcMessage>,
40  mut rx: mpsc::UnboundedReceiver<WriterRequest>,
41  pending_requests: Arc<dashmap::DashMap<u64, PendingRequest>>,
42) {
43  use futures::SinkExt;
44
45  while let Some(request) = rx.recv().await {
46    let sequence = request.msg.header.sequence;
47
48    // Register response callback if this is a sync call
49    if let Some(response_tx) = request.response_tx {
50      pending_requests.insert(sequence, PendingRequest { response_tx });
51    }
52
53    // Send the message
54    if let Err(e) = writer.send(request.msg).await {
55      tracing::error!("Writer task send error: {:?}", e);
56      // Clean up pending request on error
57      pending_requests.remove(&sequence);
58      break;
59    }
60  }
61
62  tracing::info!("Writer task shutting down");
63}
64
65/// Reader task that routes responses to the appropriate oneshot channels
66async fn reader_task(
67  mut reader: futures::stream::SplitStream<Framed<TcpStream, RpcCodec>>,
68  pending_requests: Arc<dashmap::DashMap<u64, PendingRequest>>,
69) {
70  while let Some(result) = reader.next().await {
71    match result {
72      | Ok(msg) => {
73        // Handle both regular responses (RSP) and system responses (SYSRSP)
74        // C++ ZooServer may send SYSRSP (type 3) for certain operations like sync/heartbeat
75        if msg.header.msg_type == zus_proto::constants::MSG_TYPE_RSP
76          || msg.header.msg_type == zus_proto::constants::MSG_TYPE_SYSRSP
77        {
78          let sequence = msg.header.sequence;
79          if let Some((_, pending)) = pending_requests.remove(&sequence) {
80            if pending.response_tx.send(msg.body).is_err() {
81              tracing::warn!("Failed to send response for sequence {}", sequence);
82            }
83          } else {
84            tracing::warn!("No pending request for sequence {}", sequence);
85          }
86        } else {
87          tracing::warn!("Unexpected message type: {}", msg.header.msg_type);
88        }
89      }
90      | Err(e) => {
91        tracing::error!("Reader task error: {:?}", e);
92        break;
93      }
94    }
95  }
96
97  tracing::info!("Reader task shutting down");
98}
99
100/// RPC Endpoint (replacing Java's RpcEndPoint)
101#[derive(Clone)]
102pub struct RpcEndpoint {
103  host: String,
104  port: u16,
105  writer_tx: mpsc::UnboundedSender<WriterRequest>,
106  sequence: Arc<AtomicU64>,
107}
108
109impl RpcEndpoint {
110  /// Create and connect to the remote endpoint
111  #[instrument(name = "rpc_connect", skip_all, fields(host = %host, port = port))]
112  pub async fn connect(host: String, port: u16) -> Result<Self> {
113    let addr = format!("{host}:{port}");
114    info!("Connecting to {}", addr);
115
116    let stream = TcpStream::connect(&addr).await?;
117    let framed = Framed::new(stream, RpcCodec::new());
118
119    // Split the stream into read and write halves
120    let (writer, reader) = framed.split();
121
122    // Shared pending requests map for correlating responses
123    let pending_requests = Arc::new(dashmap::DashMap::new());
124
125    // Create channel for writer task
126    let (writer_tx, writer_rx) = mpsc::unbounded_channel();
127
128    // Spawn writer task
129    tokio::spawn(writer_task(writer, writer_rx, pending_requests.clone()));
130
131    // Spawn reader task
132    tokio::spawn(reader_task(reader, pending_requests));
133
134    Ok(Self {
135      host,
136      port,
137      writer_tx,
138      sequence: Arc::new(AtomicU64::new(1)),
139    })
140  }
141
142  /// Synchronous RPC call
143  #[instrument(name = "rpc_sync_call", skip(self, body), fields(endpoint = %self.address(), timeout_ms = timeout_ms))]
144  pub async fn sync_call(&self, method: Bytes, body: Bytes, timeout_ms: u64) -> Result<Bytes> {
145    // Generate sequence number
146    let sequence = self.sequence.fetch_add(1, Ordering::SeqCst);
147
148    // Create the request message
149    let msg = RpcMessage::new_request(sequence, method, body);
150
151    // Create oneshot channel for response
152    let (response_tx, response_rx) = oneshot::channel();
153
154    // Send via channel with response callback
155    self
156      .writer_tx
157      .send(WriterRequest {
158        msg,
159        response_tx: Some(response_tx),
160      })
161      .map_err(|_| ZusError::Connection("Writer channel closed".to_string()))?;
162
163    // Wait for response with timeout
164    match timeout(Duration::from_millis(timeout_ms), response_rx).await {
165      | Ok(Ok(response)) => Ok(response),
166      | Ok(Err(_)) => Err(ZusError::Connection("Response channel closed".to_string())),
167      | Err(_) => Err(ZusError::Timeout),
168    }
169  }
170
171  /// Asynchronous RPC call (fire and forget)
172  #[instrument(name = "rpc_notify_call", skip(self, body), fields(endpoint = %self.address()))]
173  pub async fn notify_call(&self, method: Bytes, body: Bytes) -> Result<()> {
174    // Generate sequence number
175    let sequence = self.sequence.fetch_add(1, Ordering::SeqCst);
176
177    // Create message
178    let msg = RpcMessage::new_notify(sequence, method, body);
179
180    // Send via channel (no response callback)
181    self
182      .writer_tx
183      .send(WriterRequest { msg, response_tx: None })
184      .map_err(|_| ZusError::Connection("Writer channel closed".to_string()))?;
185
186    Ok(())
187  }
188
189  /// Get endpoint address
190  pub fn address(&self) -> String {
191    format!("{}:{}", self.host, self.port)
192  }
193
194  pub fn host(&self) -> &str {
195    &self.host
196  }
197
198  pub fn port(&self) -> u16 {
199    self.port
200  }
201}
202
#[cfg(test)]
mod tests {
  use {super::*, tokio::net::TcpListener};

  /// Binds a listener on an ephemeral localhost port for an endpoint to dial.
  async fn local_listener() -> (TcpListener, u16) {
    let Ok(listener) = TcpListener::bind("127.0.0.1:0").await else {
      panic!("failed to bind an ephemeral localhost port");
    };
    let Ok(local) = listener.local_addr() else {
      panic!("failed to read the listener address");
    };
    (listener, local.port())
  }

  /// Connects an endpoint to `127.0.0.1:port`, panicking on failure.
  async fn connect_local(port: u16) -> RpcEndpoint {
    let Ok(endpoint) = RpcEndpoint::connect("127.0.0.1".to_string(), port).await else {
      panic!("failed to connect to local listener on port {port}");
    };
    endpoint
  }

  #[tokio::test]
  async fn test_endpoint_address() {
    let (_listener, port) = local_listener().await;
    let endpoint = connect_local(port).await;

    assert_eq!(endpoint.host(), "127.0.0.1");
    assert_eq!(endpoint.port(), port);
    assert_eq!(endpoint.address(), format!("127.0.0.1:{port}"));
  }

  #[tokio::test]
  async fn test_sync_call_times_out_without_response() {
    let (listener, port) = local_listener().await;
    let endpoint = connect_local(port).await;
    // Keep the server side of the connection open, but never answer.
    let Ok((_server_side, _)) = listener.accept().await else {
      panic!("failed to accept the endpoint's connection");
    };

    let result = endpoint
      .sync_call(Bytes::from_static(b"method"), Bytes::from_static(b"ping"), 50)
      .await;
    assert!(matches!(result, Err(ZusError::Timeout)));
  }
}