1use std::default::Default;
2use std::fmt;
3use std::fmt::{Debug, Display};
4use std::ops::Deref;
5use std::sync::Arc;
6use std::time::Duration;
7
8use fluvio_protocol::Version;
9use tracing::{debug, instrument, info};
10
11use fluvio_protocol::api::RequestMessage;
12use fluvio_protocol::api::Request;
13use fluvio_protocol::link::versions::{ApiVersions, ApiVersionsRequest, ApiVersionsResponse};
14use fluvio_future::net::{DomainConnector, DefaultDomainConnector};
15use fluvio_future::retry::retry_if;
16
17use crate::{SocketError, FluvioSocket, SharedMultiplexerSocket, AsyncResponse};
18
19pub trait SerialFrame: Display {
21 fn config(&self) -> &ClientConfig;
23}
24
25pub struct VersionedSocket {
28 socket: FluvioSocket,
29 config: Arc<ClientConfig>,
30 versions: Versions,
31}
32
33impl fmt::Display for VersionedSocket {
34 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35 write!(f, "config {}", self.config)
36 }
37}
38
39impl SerialFrame for VersionedSocket {
40 fn config(&self) -> &ClientConfig {
41 &self.config
42 }
43}
44
45impl VersionedSocket {
46 #[instrument(skip(socket, config))]
48 pub async fn connect(
49 mut socket: FluvioSocket,
50 config: Arc<ClientConfig>,
51 ) -> Result<Self, SocketError> {
52 let version = ApiVersionsRequest {
56 client_version: crate::built_info::PKG_VERSION.into(),
57 client_os: crate::built_info::CFG_OS.into(),
58 client_arch: crate::built_info::CFG_TARGET_ARCH.into(),
59 };
60
61 debug!(client_version = %version.client_version, "querying versions");
62 let mut req_msg = RequestMessage::new_request(version);
63 req_msg.get_mut_header().set_client_id(&config.client_id);
64
65 let response: ApiVersionsResponse = (socket.send(&req_msg).await?).response;
66 let versions = Versions::new(response);
67
68 debug!("versions: {:#?}", versions);
69
70 Ok(Self {
71 socket,
72 config,
73 versions,
74 })
75 }
76
77 pub fn split(self) -> (FluvioSocket, Arc<ClientConfig>, Versions) {
78 (self.socket, self.config, self.versions)
79 }
80}
81
82pub struct ClientConfig {
85 addr: String,
86 client_id: String,
87 connector: DomainConnector,
88 use_spu_local_address: bool,
89}
90
91impl Debug for ClientConfig {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93 write!(
94 f,
95 "ClientConfig {{ addr: {}, client_id: {} }}",
96 self.addr, self.client_id
97 )
98 }
99}
100
101impl fmt::Display for ClientConfig {
102 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
103 write!(f, "addr {}", self.addr)
104 }
105}
106
107impl From<String> for ClientConfig {
108 fn from(addr: String) -> Self {
109 Self::with_addr(addr)
110 }
111}
112
113impl ClientConfig {
114 pub fn new(
115 addr: impl Into<String>,
116 connector: DomainConnector,
117 use_spu_local_address: bool,
118 ) -> Self {
119 Self {
120 addr: addr.into(),
121 client_id: "fluvio".to_owned(),
122 connector,
123 use_spu_local_address,
124 }
125 }
126
127 #[allow(clippy::box_default)]
128 pub fn with_addr(addr: String) -> Self {
129 Self::new(addr, Box::new(DefaultDomainConnector::default()), false)
130 }
131
132 pub fn addr(&self) -> &str {
133 &self.addr
134 }
135
136 pub fn client_id(&self) -> &str {
137 &self.client_id
138 }
139
140 pub fn use_spu_local_address(&self) -> bool {
141 self.use_spu_local_address
142 }
143
144 pub fn connector(&self) -> &DomainConnector {
145 &self.connector
146 }
147
148 pub fn set_client_id(&mut self, id: impl Into<String>) {
149 self.client_id = id.into();
150 }
151
152 pub fn set_addr(&mut self, domain: String) {
153 self.addr = domain
154 }
155
156 #[instrument(skip(self))]
157 pub async fn connect(self) -> Result<VersionedSocket, SocketError> {
158 debug!(add = %self.addr, "try connection to");
159 let socket =
160 FluvioSocket::connect_with_connector(&self.addr, self.connector.as_ref()).await?;
161 info!(add = %self.addr, "connect to socket");
162 VersionedSocket::connect(socket, Arc::new(self)).await
163 }
164
165 #[instrument(skip(self))]
167 pub fn with_prefix_sni_domain(&self, prefix: &str) -> Self {
168 let new_domain = format!("{}.{}", prefix, self.connector.domain());
169 debug!(sni_domain = %new_domain);
170 let connector = self.connector.new_domain(new_domain);
171
172 Self {
173 addr: self.addr.clone(),
174 client_id: self.client_id.clone(),
175 connector,
176 use_spu_local_address: self.use_spu_local_address,
177 }
178 }
179
180 pub fn recreate(&self) -> Self {
181 Self {
182 addr: self.addr.clone(),
183 client_id: self.client_id.clone(),
184 connector: self
185 .connector
186 .new_domain(self.connector.domain().to_owned()),
187 use_spu_local_address: self.use_spu_local_address,
188 }
189 }
190}
191
192#[derive(Clone, Debug)]
194pub struct Versions {
195 api_versions: ApiVersions,
196 platform_version: semver::Version,
197}
198
199impl Versions {
200 pub fn new(version_response: ApiVersionsResponse) -> Self {
201 Self {
202 api_versions: version_response.api_keys,
203 platform_version: version_response.platform_version.to_semver(),
204 }
205 }
206
207 pub fn platform_version(&self) -> &semver::Version {
212 &self.platform_version
213 }
214
215 pub fn lookup_version<R: Request>(&self) -> Option<i16> {
217 for version in &self.api_versions {
218 if version.api_key == R::API_KEY as i16 {
219 if version.max_version >= R::MIN_API_VERSION
221 && version.min_version <= R::MAX_API_VERSION
222 {
223 return Some(R::MAX_API_VERSION.min(version.max_version));
224 }
225 }
226 }
227
228 None
229 }
230}
231
232pub struct VersionedSerialSocket {
234 socket: SharedMultiplexerSocket,
235 config: Arc<ClientConfig>,
236 versions: Versions,
237}
238
239impl Deref for VersionedSerialSocket {
240 type Target = SharedMultiplexerSocket;
241
242 fn deref(&self) -> &Self::Target {
243 &self.socket
244 }
245}
246
247impl fmt::Display for VersionedSerialSocket {
248 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
249 write!(f, "config: {}, {:?}", self.config, self.socket)
250 }
251}
252unsafe impl Send for VersionedSerialSocket {}
253
254impl VersionedSerialSocket {
255 pub fn new(
256 socket: SharedMultiplexerSocket,
257 config: Arc<ClientConfig>,
258 versions: Versions,
259 ) -> Self {
260 Self {
261 socket,
262 config,
263 versions,
264 }
265 }
266
267 pub fn versions(&self) -> &Versions {
268 &self.versions
269 }
270
271 pub fn new_socket(&self) -> SharedMultiplexerSocket {
273 self.socket.clone()
274 }
275
276 pub fn is_stale(&self) -> bool {
278 self.socket.is_stale()
279 }
280
281 fn check_liveness(&self) -> Result<(), SocketError> {
282 if self.is_stale() {
283 Err(SocketError::SocketStale)
284 } else {
285 Ok(())
286 }
287 }
288
289 #[instrument(level = "trace", skip(self, request))]
291 pub async fn send_receive<R>(&self, request: R) -> Result<R::Response, SocketError>
292 where
293 R: Request + Send + Sync,
294 {
295 self.check_liveness()?;
296
297 let req_msg = self.new_request(request, self.versions.lookup_version::<R>());
298
299 self.socket.send_and_receive(req_msg).await
301 }
302
303 #[instrument(level = "trace", skip(self, request))]
305 pub async fn send_async<R>(&self, request: R) -> Result<AsyncResponse<R>, SocketError>
306 where
307 R: Request + Send + Sync,
308 {
309 self.check_liveness()?;
310
311 let req_msg = self.new_request(request, self.lookup_version::<R>());
312
313 self.socket.send_async(req_msg).await
315 }
316
317 pub fn lookup_version<R>(&self) -> Option<Version>
319 where
320 R: Request,
321 {
322 self.versions.lookup_version::<R>()
323 }
324
325 #[instrument(level = "trace", skip(self, request))]
327 pub async fn send_receive_with_retry<R, I>(
328 &self,
329 request: R,
330 retries: I,
331 ) -> Result<R::Response, SocketError>
332 where
333 R: Request + Send + Sync + Clone,
334 I: IntoIterator<Item = Duration> + Debug + Send,
335 {
336 self.check_liveness()?;
337
338 let req_msg = self.new_request(request, self.versions.lookup_version::<R>());
339
340 retry_if(
342 retries,
343 || self.socket.send_and_receive(req_msg.clone()),
344 is_retryable,
345 )
346 .await
347 }
348
349 #[instrument(level = "trace", skip(self, request, version))]
351 pub fn new_request<R>(&self, request: R, version: Option<i16>) -> RequestMessage<R>
352 where
353 R: Request + Send,
354 {
355 let mut req_msg = RequestMessage::new_request(request);
356 req_msg
357 .get_mut_header()
358 .set_client_id(&self.config().client_id);
359
360 if let Some(ver) = version {
361 req_msg.get_mut_header().set_api_version(ver);
362 }
363 req_msg
364 }
365}
366
367impl SerialFrame for VersionedSerialSocket {
368 fn config(&self) -> &ClientConfig {
369 &self.config
370 }
371}
372
373fn is_retryable(err: &SocketError) -> bool {
374 use std::io::ErrorKind;
375 match err {
376 SocketError::Io { source, .. } => matches!(
377 source.kind(),
378 ErrorKind::AddrNotAvailable
379 | ErrorKind::ConnectionAborted
380 | ErrorKind::ConnectionRefused
381 | ErrorKind::ConnectionReset
382 | ErrorKind::NotConnected
383 | ErrorKind::Other
384 | ErrorKind::TimedOut
385 | ErrorKind::UnexpectedEof
386 | ErrorKind::Interrupted
387 ),
388
389 SocketError::SocketClosed | SocketError::SocketStale => false,
390 }
391}
392
#[cfg(test)]
mod test {
    use fluvio_protocol::Decoder;
    use fluvio_protocol::Encoder;
    use fluvio_protocol::api::Request;
    use fluvio_protocol::link::versions::ApiVersionKey;

    use super::ApiVersionsResponse;
    use super::Versions;

    /// Request for key 1000 accepting API versions 6..=9.
    #[derive(Encoder, Decoder, Default, Debug)]
    struct T1;

    impl Request for T1 {
        const API_KEY: u16 = 1000;
        const MIN_API_VERSION: i16 = 6;
        const MAX_API_VERSION: i16 = 9;

        type Response = u8;
    }

    /// Request for key 1000 accepting API versions 2..=3.
    #[derive(Encoder, Decoder, Default, Debug)]
    struct T2;

    impl Request for T2 {
        const API_KEY: u16 = 1000;
        const MIN_API_VERSION: i16 = 2;
        const MAX_API_VERSION: i16 = 3;

        type Response = u8;
    }

    /// Request for key 2000, which these tests never advertise.
    #[derive(Encoder, Decoder, Default, Debug)]
    struct T3;

    impl Request for T3 {
        const API_KEY: u16 = 2000;
        const MIN_API_VERSION: i16 = 0;
        const MAX_API_VERSION: i16 = 1;

        type Response = u8;
    }

    /// Builds a `Versions` table advertising only key 1000 with the range
    /// `min_version..=max_version`.
    fn versions_for_key_1000(min_version: i16, max_version: i16) -> Versions {
        let mut response = ApiVersionsResponse::default();
        response.api_keys.push(ApiVersionKey {
            api_key: 1000,
            min_version,
            max_version,
        });
        Versions::new(response)
    }

    #[test]
    fn test_version_lookup() {
        let versions = versions_for_key_1000(5, 10);

        // T1's 6..=9 lies inside the server's 5..=10, so T1's own maximum wins.
        assert_eq!(versions.lookup_version::<T1>(), Some(9));
        // T2's 2..=3 lies entirely below the server's minimum of 5.
        assert_eq!(versions.lookup_version::<T2>(), None);
    }

    #[test]
    fn test_version_lookup_clamps_to_server_max() {
        // Server 5..=7 overlaps T1's 6..=9; the result is capped at the server max.
        let versions = versions_for_key_1000(5, 7);
        assert_eq!(versions.lookup_version::<T1>(), Some(7));
    }

    #[test]
    fn test_version_lookup_range_boundaries() {
        // Ranges that touch at a single version still count as overlapping.
        assert_eq!(versions_for_key_1000(9, 12).lookup_version::<T1>(), Some(9));
        assert_eq!(versions_for_key_1000(0, 6).lookup_version::<T1>(), Some(6));
    }

    #[test]
    fn test_version_lookup_unknown_key() {
        // The range would be compatible, but the API key is not advertised.
        let versions = versions_for_key_1000(0, 10);
        assert_eq!(versions.lookup_version::<T3>(), None);
    }

    #[test]
    fn test_version_lookup_empty_table() {
        let versions = Versions::new(ApiVersionsResponse::default());
        assert_eq!(versions.lookup_version::<T1>(), None);
    }
}
441}