// fluvio_socket/versioned.rs

1use std::default::Default;
2use std::fmt;
3use std::fmt::{Debug, Display};
4use std::ops::Deref;
5use std::sync::Arc;
6use std::time::Duration;
7
8use fluvio_protocol::Version;
9use tracing::{debug, instrument, info};
10
11use fluvio_protocol::api::RequestMessage;
12use fluvio_protocol::api::Request;
13use fluvio_protocol::link::versions::{ApiVersions, ApiVersionsRequest, ApiVersionsResponse};
14use fluvio_future::net::{DomainConnector, DefaultDomainConnector};
15use fluvio_future::retry::retry_if;
16
17use crate::{SocketError, FluvioSocket, SharedMultiplexerSocket, AsyncResponse};
18
19/// Frame with request and response
20pub trait SerialFrame: Display {
21    /// client config
22    fn config(&self) -> &ClientConfig;
23}
24
25/// This sockets knows about support versions
26/// Version information are automatically  insert into request
27pub struct VersionedSocket {
28    socket: FluvioSocket,
29    config: Arc<ClientConfig>,
30    versions: Versions,
31}
32
33impl fmt::Display for VersionedSocket {
34    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35        write!(f, "config {}", self.config)
36    }
37}
38
39impl SerialFrame for VersionedSocket {
40    fn config(&self) -> &ClientConfig {
41        &self.config
42    }
43}
44
45impl VersionedSocket {
46    /// connect to end point and retrieve versions
47    #[instrument(skip(socket, config))]
48    pub async fn connect(
49        mut socket: FluvioSocket,
50        config: Arc<ClientConfig>,
51    ) -> Result<Self, SocketError> {
52        // now get versions
53        // Query for API versions
54
55        let version = ApiVersionsRequest {
56            client_version: crate::built_info::PKG_VERSION.into(),
57            client_os: crate::built_info::CFG_OS.into(),
58            client_arch: crate::built_info::CFG_TARGET_ARCH.into(),
59        };
60
61        debug!(client_version = %version.client_version, "querying versions");
62        let mut req_msg = RequestMessage::new_request(version);
63        req_msg.get_mut_header().set_client_id(&config.client_id);
64
65        let response: ApiVersionsResponse = (socket.send(&req_msg).await?).response;
66        let versions = Versions::new(response);
67
68        debug!("versions: {:#?}", versions);
69
70        Ok(Self {
71            socket,
72            config,
73            versions,
74        })
75    }
76
77    pub fn split(self) -> (FluvioSocket, Arc<ClientConfig>, Versions) {
78        (self.socket, self.config, self.versions)
79    }
80}
81
82/// Low level configuration option to directly connect to Fluvio
83/// This can bypass higher level validation required for CLI and end user application
84pub struct ClientConfig {
85    addr: String,
86    client_id: String,
87    connector: DomainConnector,
88    use_spu_local_address: bool,
89}
90
91impl Debug for ClientConfig {
92    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93        write!(
94            f,
95            "ClientConfig {{ addr: {}, client_id: {} }}",
96            self.addr, self.client_id
97        )
98    }
99}
100
101impl fmt::Display for ClientConfig {
102    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
103        write!(f, "addr {}", self.addr)
104    }
105}
106
107impl From<String> for ClientConfig {
108    fn from(addr: String) -> Self {
109        Self::with_addr(addr)
110    }
111}
112
113impl ClientConfig {
114    pub fn new(
115        addr: impl Into<String>,
116        connector: DomainConnector,
117        use_spu_local_address: bool,
118    ) -> Self {
119        Self {
120            addr: addr.into(),
121            client_id: "fluvio".to_owned(),
122            connector,
123            use_spu_local_address,
124        }
125    }
126
127    #[allow(clippy::box_default)]
128    pub fn with_addr(addr: String) -> Self {
129        Self::new(addr, Box::new(DefaultDomainConnector::default()), false)
130    }
131
132    pub fn addr(&self) -> &str {
133        &self.addr
134    }
135
136    pub fn client_id(&self) -> &str {
137        &self.client_id
138    }
139
140    pub fn use_spu_local_address(&self) -> bool {
141        self.use_spu_local_address
142    }
143
144    pub fn connector(&self) -> &DomainConnector {
145        &self.connector
146    }
147
148    pub fn set_client_id(&mut self, id: impl Into<String>) {
149        self.client_id = id.into();
150    }
151
152    pub fn set_addr(&mut self, domain: String) {
153        self.addr = domain
154    }
155
156    #[instrument(skip(self))]
157    pub async fn connect(self) -> Result<VersionedSocket, SocketError> {
158        debug!(add = %self.addr, "try connection to");
159        let socket =
160            FluvioSocket::connect_with_connector(&self.addr, self.connector.as_ref()).await?;
161        info!(add = %self.addr, "connect to socket");
162        VersionedSocket::connect(socket, Arc::new(self)).await
163    }
164
165    /// create new config with prefix add to domain, this is useful for SNI
166    #[instrument(skip(self))]
167    pub fn with_prefix_sni_domain(&self, prefix: &str) -> Self {
168        let new_domain = format!("{}.{}", prefix, self.connector.domain());
169        debug!(sni_domain = %new_domain);
170        let connector = self.connector.new_domain(new_domain);
171
172        Self {
173            addr: self.addr.clone(),
174            client_id: self.client_id.clone(),
175            connector,
176            use_spu_local_address: self.use_spu_local_address,
177        }
178    }
179
180    pub fn recreate(&self) -> Self {
181        Self {
182            addr: self.addr.clone(),
183            client_id: self.client_id.clone(),
184            connector: self
185                .connector
186                .new_domain(self.connector.domain().to_owned()),
187            use_spu_local_address: self.use_spu_local_address,
188        }
189    }
190}
191
192/// wrap around versions
193#[derive(Clone, Debug)]
194pub struct Versions {
195    api_versions: ApiVersions,
196    platform_version: semver::Version,
197}
198
199impl Versions {
200    pub fn new(version_response: ApiVersionsResponse) -> Self {
201        Self {
202            api_versions: version_response.api_keys,
203            platform_version: version_response.platform_version.to_semver(),
204        }
205    }
206
207    /// Tells the platform version reported by the SC
208    ///
209    /// The platform version refers to the value in the VERSION
210    /// file at the time the SC was compiled.
211    pub fn platform_version(&self) -> &semver::Version {
212        &self.platform_version
213    }
214
215    /// Given an API key, it returns maximum compatible version. None if not found
216    pub fn lookup_version<R: Request>(&self) -> Option<i16> {
217        for version in &self.api_versions {
218            if version.api_key == R::API_KEY as i16 {
219                // try to find most latest maximum version
220                if version.max_version >= R::MIN_API_VERSION
221                    && version.min_version <= R::MAX_API_VERSION
222                {
223                    return Some(R::MAX_API_VERSION.min(version.max_version));
224                }
225            }
226        }
227
228        None
229    }
230}
231
232/// Connection that perform request/response
233pub struct VersionedSerialSocket {
234    socket: SharedMultiplexerSocket,
235    config: Arc<ClientConfig>,
236    versions: Versions,
237}
238
239impl Deref for VersionedSerialSocket {
240    type Target = SharedMultiplexerSocket;
241
242    fn deref(&self) -> &Self::Target {
243        &self.socket
244    }
245}
246
247impl fmt::Display for VersionedSerialSocket {
248    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
249        write!(f, "config: {}, {:?}", self.config, self.socket)
250    }
251}
252unsafe impl Send for VersionedSerialSocket {}
253
254impl VersionedSerialSocket {
255    pub fn new(
256        socket: SharedMultiplexerSocket,
257        config: Arc<ClientConfig>,
258        versions: Versions,
259    ) -> Self {
260        Self {
261            socket,
262            config,
263            versions,
264        }
265    }
266
267    pub fn versions(&self) -> &Versions {
268        &self.versions
269    }
270
271    /// get new socket
272    pub fn new_socket(&self) -> SharedMultiplexerSocket {
273        self.socket.clone()
274    }
275
276    /// Check if inner socket is stale
277    pub fn is_stale(&self) -> bool {
278        self.socket.is_stale()
279    }
280
281    fn check_liveness(&self) -> Result<(), SocketError> {
282        if self.is_stale() {
283            Err(SocketError::SocketStale)
284        } else {
285            Ok(())
286        }
287    }
288
289    /// send and wait for reply serially
290    #[instrument(level = "trace", skip(self, request))]
291    pub async fn send_receive<R>(&self, request: R) -> Result<R::Response, SocketError>
292    where
293        R: Request + Send + Sync,
294    {
295        self.check_liveness()?;
296
297        let req_msg = self.new_request(request, self.versions.lookup_version::<R>());
298
299        // send request & save response
300        self.socket.send_and_receive(req_msg).await
301    }
302
303    /// send and do not wait for reply
304    #[instrument(level = "trace", skip(self, request))]
305    pub async fn send_async<R>(&self, request: R) -> Result<AsyncResponse<R>, SocketError>
306    where
307        R: Request + Send + Sync,
308    {
309        self.check_liveness()?;
310
311        let req_msg = self.new_request(request, self.lookup_version::<R>());
312
313        // send request & get a Future that resolves to response
314        self.socket.send_async(req_msg).await
315    }
316
317    /// look up version for the request
318    pub fn lookup_version<R>(&self) -> Option<Version>
319    where
320        R: Request,
321    {
322        self.versions.lookup_version::<R>()
323    }
324
325    /// send, wait for reply and retry if failed
326    #[instrument(level = "trace", skip(self, request))]
327    pub async fn send_receive_with_retry<R, I>(
328        &self,
329        request: R,
330        retries: I,
331    ) -> Result<R::Response, SocketError>
332    where
333        R: Request + Send + Sync + Clone,
334        I: IntoIterator<Item = Duration> + Debug + Send,
335    {
336        self.check_liveness()?;
337
338        let req_msg = self.new_request(request, self.versions.lookup_version::<R>());
339
340        // send request & retry it if result is Err
341        retry_if(
342            retries,
343            || self.socket.send_and_receive(req_msg.clone()),
344            is_retryable,
345        )
346        .await
347    }
348
349    /// create new request based on version
350    #[instrument(level = "trace", skip(self, request, version))]
351    pub fn new_request<R>(&self, request: R, version: Option<i16>) -> RequestMessage<R>
352    where
353        R: Request + Send,
354    {
355        let mut req_msg = RequestMessage::new_request(request);
356        req_msg
357            .get_mut_header()
358            .set_client_id(&self.config().client_id);
359
360        if let Some(ver) = version {
361            req_msg.get_mut_header().set_api_version(ver);
362        }
363        req_msg
364    }
365}
366
367impl SerialFrame for VersionedSerialSocket {
368    fn config(&self) -> &ClientConfig {
369        &self.config
370    }
371}
372
373fn is_retryable(err: &SocketError) -> bool {
374    use std::io::ErrorKind;
375    match err {
376        SocketError::Io { source, .. } => matches!(
377            source.kind(),
378            ErrorKind::AddrNotAvailable
379                | ErrorKind::ConnectionAborted
380                | ErrorKind::ConnectionRefused
381                | ErrorKind::ConnectionReset
382                | ErrorKind::NotConnected
383                | ErrorKind::Other
384                | ErrorKind::TimedOut
385                | ErrorKind::UnexpectedEof
386                | ErrorKind::Interrupted
387        ),
388
389        SocketError::SocketClosed | SocketError::SocketStale => false,
390    }
391}
392
#[cfg(test)]
mod test {
    use fluvio_protocol::Decoder;
    use fluvio_protocol::Encoder;
    use fluvio_protocol::api::Request;
    use fluvio_protocol::link::versions::ApiVersionKey;

    use super::ApiVersionsResponse;
    use super::Versions;

    // Range [6, 9] lies inside the peer's [5, 10].
    #[derive(Encoder, Decoder, Default, Debug)]
    struct T1;

    impl Request for T1 {
        const API_KEY: u16 = 1000;
        const MIN_API_VERSION: i16 = 6;
        const MAX_API_VERSION: i16 = 9;

        type Response = u8;
    }

    // Same API key, but range [2, 3] is entirely below the peer's [5, 10].
    #[derive(Encoder, Decoder, Default, Debug)]
    struct T2;

    impl Request for T2 {
        const API_KEY: u16 = 1000;
        const MIN_API_VERSION: i16 = 2;
        const MAX_API_VERSION: i16 = 3;

        type Response = u8;
    }

    // Range [8, 12] extends above the peer's maximum of 10.
    #[derive(Encoder, Decoder, Default, Debug)]
    struct T3;

    impl Request for T3 {
        const API_KEY: u16 = 1000;
        const MIN_API_VERSION: i16 = 8;
        const MAX_API_VERSION: i16 = 12;

        type Response = u8;
    }

    // Range [0, 5] only touches the peer's minimum of 5.
    #[derive(Encoder, Decoder, Default, Debug)]
    struct T4;

    impl Request for T4 {
        const API_KEY: u16 = 1000;
        const MIN_API_VERSION: i16 = 0;
        const MAX_API_VERSION: i16 = 5;

        type Response = u8;
    }

    // API key the peer does not advertise at all.
    #[derive(Encoder, Decoder, Default, Debug)]
    struct T5;

    impl Request for T5 {
        const API_KEY: u16 = 2000;
        const MIN_API_VERSION: i16 = 0;
        const MAX_API_VERSION: i16 = 1;

        type Response = u8;
    }

    /// Peer that advertises only API key 1000 with versions [5, 10].
    fn peer_versions() -> Versions {
        let mut response = ApiVersionsResponse::default();
        response.api_keys.push(ApiVersionKey {
            api_key: 1000,
            min_version: 5,
            max_version: 10,
        });
        Versions::new(response)
    }

    #[test]
    fn test_version_lookup() {
        let versions = peer_versions();

        // Overlapping range: our own maximum is within the peer's, so it wins.
        assert_eq!(versions.lookup_version::<T1>(), Some(9));
        // Same key but no overlapping range: None.
        assert_eq!(versions.lookup_version::<T2>(), None);
    }

    #[test]
    fn test_version_lookup_clamps_to_peer_max() {
        assert_eq!(peer_versions().lookup_version::<T3>(), Some(10));
    }

    #[test]
    fn test_version_lookup_boundary_overlap() {
        assert_eq!(peer_versions().lookup_version::<T4>(), Some(5));
    }

    #[test]
    fn test_version_lookup_unknown_api_key() {
        // None if api_key not found.
        assert_eq!(peer_versions().lookup_version::<T5>(), None);
    }
}