1use anyhow::{Context, Result, bail};
7use bytes::Bytes;
8use std::{path::Path, sync::Arc};
9use synapse_primitives::{InterfaceId, MethodId};
10use synapse_proto::RpcRequest;
11use synapse_proto::RpcStatus;
12#[cfg(feature = "otlp")]
13use synapse_proto::{HeaderEntry, header_entry};
14use synapse_rpc::HttpRpcClient;
15use tracing::{debug, warn};
16
17#[derive(Clone)]
25pub struct SynClient {
26 inner: Arc<SynClientInner>,
27}
28
29struct SynClientInner {
30 http_client: HttpRpcClient,
31}
32
33impl SynClient {
34 pub fn new(gateway_url: impl Into<String>) -> Self {
44 Self {
45 inner: Arc::new(SynClientInner {
46 http_client: HttpRpcClient::json(gateway_url),
47 }),
48 }
49 }
50
51 pub fn with_protobuf(gateway_url: impl Into<String>) -> Self {
53 Self {
54 inner: Arc::new(SynClientInner {
55 http_client: HttpRpcClient::protobuf(gateway_url),
56 }),
57 }
58 }
59
60 pub fn with_mtls(
62 gateway_url: impl Into<String>,
63 cert_path: impl AsRef<Path>,
64 key_path: impl AsRef<Path>,
65 ca_cert_path: impl AsRef<Path>,
66 ) -> Result<Self> {
67 Ok(Self {
68 inner: Arc::new(SynClientInner {
69 http_client: HttpRpcClient::protobuf_mtls(
70 gateway_url,
71 cert_path,
72 key_path,
73 ca_cert_path,
74 )?,
75 }),
76 })
77 }
78
79 pub fn with_timeout(self, timeout: std::time::Duration) -> Self {
81 Self {
82 inner: Arc::new(SynClientInner {
83 http_client: HttpRpcClient::new(
84 self.gateway_url(),
85 self.inner.http_client.content_type(),
86 )
87 .with_timeout(timeout),
88 }),
89 }
90 }
91
92 pub async fn call(
99 &self,
100 interface: impl Into<InterfaceId>,
101 method: impl Into<MethodId>,
102 payload: Bytes,
103 ) -> Result<Bytes> {
104 let interface_id = interface.into();
105 let method_id = method.into();
106
107 debug!(
108 "Calling {}.{} ({} bytes)",
109 u32::from(interface_id),
110 u32::from(method_id),
111 payload.len()
112 );
113
114 #[allow(unused_mut)]
116 let mut headers = Vec::new();
117
118 #[cfg(feature = "otlp")]
120 {
121 use opentelemetry::trace::TraceContextExt;
122 use tracing_opentelemetry::OpenTelemetrySpanExt;
123
124 let span = tracing::Span::current();
125 let otel_ctx = span.context();
126 let span_ref = otel_ctx.span();
127 let span_ctx = span_ref.span_context();
128
129 if span_ctx.is_valid() {
130 headers.push(HeaderEntry {
131 key: u32::from(*synapse_primitives::id::well_known::TRACE_ID),
132 value: Some(header_entry::Value::StringValue(format!(
133 "{:032x}",
134 span_ctx.trace_id()
135 ))),
136 });
137 headers.push(HeaderEntry {
138 key: u32::from(*synapse_primitives::id::well_known::SPAN_ID),
139 value: Some(header_entry::Value::StringValue(format!(
140 "{:016x}",
141 span_ctx.span_id()
142 ))),
143 });
144 }
145 }
146
147 let request = RpcRequest {
148 interface_id: interface_id.into(),
149 method_id: method_id.into(),
150 headers,
151 payload,
152 sent_at_unix_ms: chrono::Utc::now().timestamp_millis(),
153 };
154
155 let response = self
157 .inner
158 .http_client
159 .call(request)
160 .await
161 .context("RPC call failed")?;
162
163 if response.status != RpcStatus::Ok as i32 {
165 let error = response.error.unwrap_or_else(|| synapse_proto::RpcError {
166 code: response.status as u32,
167 message: format!("RPC failed with status {}", response.status),
168 details: vec![],
169 });
170
171 warn!("RPC call failed: {} - {}", error.code, error.message);
172 bail!("RPC error {}: {}", error.code, error.message);
173 }
174
175 debug!("RPC call succeeded ({} bytes)", response.payload.len());
176 Ok(response.payload)
177 }
178
179 pub async fn call_json<TReq, TResp>(
183 &self,
184 interface: impl Into<InterfaceId>,
185 method: impl Into<MethodId>,
186 request: &TReq,
187 ) -> Result<TResp>
188 where
189 TReq: serde::Serialize,
190 TResp: serde::de::DeserializeOwned,
191 {
192 let payload = serde_json::to_vec(request).context("Failed to serialize request")?;
194
195 let response_bytes = self.call(interface, method, Bytes::from(payload)).await?;
197
198 let response =
200 serde_json::from_slice(&response_bytes).context("Failed to deserialize response")?;
201
202 Ok(response)
203 }
204
205 pub async fn call_proto<TReq, TResp>(
210 &self,
211 interface: impl Into<InterfaceId>,
212 method: impl Into<MethodId>,
213 request: &TReq,
214 ) -> Result<TResp>
215 where
216 TReq: prost::Message,
217 TResp: prost::Message + Default,
218 {
219 let payload = request.encode_to_vec();
221
222 let response_bytes = self.call(interface, method, Bytes::from(payload)).await?;
224
225 let response =
227 TResp::decode(response_bytes.as_ref()).context("Failed to deserialize response")?;
228
229 Ok(response)
230 }
231
232 pub fn gateway_url(&self) -> &str {
234 self.inner.http_client.gateway_url()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = SynClient::new("http://localhost:8080");
        assert_eq!(client.gateway_url(), "http://localhost:8080");

        // Changing the timeout must not change which gateway the client targets.
        let client =
            SynClient::new("http://gateway:5000").with_timeout(std::time::Duration::from_secs(60));
        assert_eq!(client.gateway_url(), "http://gateway:5000");
    }

    #[test]
    fn test_protobuf_client_creation() {
        let client = SynClient::with_protobuf("http://localhost:8080");
        assert_eq!(client.gateway_url(), "http://localhost:8080");
    }

    #[test]
    fn test_client_clone() {
        let client1 = SynClient::new("http://localhost:8080");
        let client2 = client1.clone();

        assert_eq!(client1.gateway_url(), client2.gateway_url());
        // Clones share one transport rather than copying it.
        assert!(Arc::ptr_eq(&client1.inner, &client2.inner));
    }

    #[tokio::test]
    #[ignore = "requires a gateway listening on http://localhost:8080"]
    async fn test_call() {
        let client = SynClient::new("http://localhost:8080");

        let interface = InterfaceId::from_name("test.Service");
        let method = MethodId::from_name("Echo");
        let payload = Bytes::from("hello");

        let _response = client.call(interface, method, payload).await;
    }
}