1use {
2 bytes::Bytes,
3 futures::{StreamExt, stream::SplitSink},
4 std::{
5 sync::{
6 Arc,
7 atomic::{AtomicU64, Ordering},
8 },
9 time::Duration,
10 },
11 tokio::{
12 net::TcpStream,
13 sync::{mpsc, oneshot},
14 time::timeout,
15 },
16 tokio_util::codec::Framed,
17 tracing::{info, instrument},
18};
19
20use crate::{
21 codec::RpcCodec,
22 error::{Result, ZusError},
23 protocol::RpcMessage,
24};
25
26struct WriterRequest {
28 msg: RpcMessage,
29 response_tx: Option<oneshot::Sender<Bytes>>,
30}
31
32struct PendingRequest {
34 response_tx: oneshot::Sender<Bytes>,
35}
36
37async fn writer_task(
39 mut writer: SplitSink<Framed<TcpStream, RpcCodec>, RpcMessage>,
40 mut rx: mpsc::UnboundedReceiver<WriterRequest>,
41 pending_requests: Arc<dashmap::DashMap<u64, PendingRequest>>,
42) {
43 use futures::SinkExt;
44
45 while let Some(request) = rx.recv().await {
46 let sequence = request.msg.header.sequence;
47
48 if let Some(response_tx) = request.response_tx {
50 pending_requests.insert(sequence, PendingRequest { response_tx });
51 }
52
53 if let Err(e) = writer.send(request.msg).await {
55 tracing::error!("Writer task send error: {:?}", e);
56 pending_requests.remove(&sequence);
58 break;
59 }
60 }
61
62 tracing::info!("Writer task shutting down");
63}
64
65async fn reader_task(
67 mut reader: futures::stream::SplitStream<Framed<TcpStream, RpcCodec>>,
68 pending_requests: Arc<dashmap::DashMap<u64, PendingRequest>>,
69) {
70 while let Some(result) = reader.next().await {
71 match result {
72 | Ok(msg) => {
73 if msg.header.msg_type == zus_proto::constants::MSG_TYPE_RSP
76 || msg.header.msg_type == zus_proto::constants::MSG_TYPE_SYSRSP
77 {
78 let sequence = msg.header.sequence;
79 if let Some((_, pending)) = pending_requests.remove(&sequence) {
80 if pending.response_tx.send(msg.body).is_err() {
81 tracing::warn!("Failed to send response for sequence {}", sequence);
82 }
83 } else {
84 tracing::warn!("No pending request for sequence {}", sequence);
85 }
86 } else {
87 tracing::warn!("Unexpected message type: {}", msg.header.msg_type);
88 }
89 }
90 | Err(e) => {
91 tracing::error!("Reader task error: {:?}", e);
92 break;
93 }
94 }
95 }
96
97 tracing::info!("Reader task shutting down");
98}
99
100#[derive(Clone)]
102pub struct RpcEndpoint {
103 host: String,
104 port: u16,
105 writer_tx: mpsc::UnboundedSender<WriterRequest>,
106 sequence: Arc<AtomicU64>,
107}
108
109impl RpcEndpoint {
110 #[instrument(name = "rpc_connect", skip_all, fields(host = %host, port = port))]
112 pub async fn connect(host: String, port: u16) -> Result<Self> {
113 let addr = format!("{host}:{port}");
114 info!("Connecting to {}", addr);
115
116 let stream = TcpStream::connect(&addr).await?;
117 let framed = Framed::new(stream, RpcCodec::new());
118
119 let (writer, reader) = framed.split();
121
122 let pending_requests = Arc::new(dashmap::DashMap::new());
124
125 let (writer_tx, writer_rx) = mpsc::unbounded_channel();
127
128 tokio::spawn(writer_task(writer, writer_rx, pending_requests.clone()));
130
131 tokio::spawn(reader_task(reader, pending_requests));
133
134 Ok(Self {
135 host,
136 port,
137 writer_tx,
138 sequence: Arc::new(AtomicU64::new(1)),
139 })
140 }
141
142 #[instrument(name = "rpc_sync_call", skip(self, body), fields(endpoint = %self.address(), timeout_ms = timeout_ms))]
144 pub async fn sync_call(&self, method: Bytes, body: Bytes, timeout_ms: u64) -> Result<Bytes> {
145 let sequence = self.sequence.fetch_add(1, Ordering::SeqCst);
147
148 let msg = RpcMessage::new_request(sequence, method, body);
150
151 let (response_tx, response_rx) = oneshot::channel();
153
154 self
156 .writer_tx
157 .send(WriterRequest {
158 msg,
159 response_tx: Some(response_tx),
160 })
161 .map_err(|_| ZusError::Connection("Writer channel closed".to_string()))?;
162
163 match timeout(Duration::from_millis(timeout_ms), response_rx).await {
165 | Ok(Ok(response)) => Ok(response),
166 | Ok(Err(_)) => Err(ZusError::Connection("Response channel closed".to_string())),
167 | Err(_) => Err(ZusError::Timeout),
168 }
169 }
170
171 #[instrument(name = "rpc_notify_call", skip(self, body), fields(endpoint = %self.address()))]
173 pub async fn notify_call(&self, method: Bytes, body: Bytes) -> Result<()> {
174 let sequence = self.sequence.fetch_add(1, Ordering::SeqCst);
176
177 let msg = RpcMessage::new_notify(sequence, method, body);
179
180 self
182 .writer_tx
183 .send(WriterRequest { msg, response_tx: None })
184 .map_err(|_| ZusError::Connection("Writer channel closed".to_string()))?;
185
186 Ok(())
187 }
188
189 pub fn address(&self) -> String {
191 format!("{}:{}", self.host, self.port)
192 }
193
194 pub fn host(&self) -> &str {
195 &self.host
196 }
197
198 pub fn port(&self) -> u16 {
199 self.port
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Binds a loopback listener on an ephemeral port and connects an endpoint to it.
    ///
    /// Returns the endpoint, the server-side socket (keep it alive so the connection stays
    /// open), and the listener's port.
    async fn connect_to_local_listener() -> (RpcEndpoint, TcpStream, u16) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let port = listener.local_addr().expect("listener address").port();

        let endpoint = match RpcEndpoint::connect("127.0.0.1".to_string(), port).await {
            | Ok(endpoint) => endpoint,
            | Err(_) => panic!("failed to connect to local listener"),
        };
        let (server_side, _) = listener.accept().await.expect("accept connection");

        (endpoint, server_side, port)
    }

    #[tokio::test]
    async fn test_endpoint_address() {
        let (endpoint, _server_side, port) = connect_to_local_listener().await;

        assert_eq!(endpoint.host(), "127.0.0.1");
        assert_eq!(endpoint.port(), port);
        assert_eq!(endpoint.address(), format!("127.0.0.1:{port}"));
    }

    #[tokio::test]
    async fn test_sync_call_times_out_without_response() {
        // The server side stays open but never replies, so the call must time out.
        let (endpoint, _server_side, _) = connect_to_local_listener().await;

        let result = endpoint
            .sync_call(Bytes::from_static(b"method"), Bytes::new(), 50)
            .await;

        assert!(matches!(result, Err(ZusError::Timeout)));
    }
}