fluvio/
spu.rs

1use std::sync::Arc;
2use std::collections::HashMap;
3use anyhow::Result;
4
5use fluvio_sc_schema::partition::PartitionSpec;
6use fluvio_sc_schema::topic::TopicSpec;
7use tracing::{debug, trace, instrument};
8use async_lock::Mutex;
9use async_trait::async_trait;
10
11use fluvio_protocol::record::ReplicaKey;
12use fluvio_protocol::api::Request;
13use fluvio_types::SpuId;
14use fluvio_socket::{
15    AsyncResponse, ClientConfig, MultiplexerSocket, SocketError, StreamSocket,
16    VersionedSerialSocket,
17};
18use crate::FluvioError;
19use crate::sync::{MetadataStores, StoreContext};
20
21/// used for connecting to spu
22#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
23#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
24pub trait SpuDirectory {
25    /// Create request/response socket to SPU for a replica
26    ///
27    /// All sockets to same SPU use a single TCP connection.
28    /// First this looks up SPU address in SPU metadata and try to see if there is an existing TCP connection.
29    /// If not, it will create a new connection and creates socket to it
30    async fn create_serial_socket(
31        &self,
32        replica: &ReplicaKey,
33    ) -> Result<VersionedSerialSocket, FluvioError>;
34
35    /// create stream to leader replica
36    async fn create_stream_with_version<R: Request>(
37        &self,
38        replica: &ReplicaKey,
39        request: R,
40        version: i16,
41    ) -> Result<AsyncResponse<R>, FluvioError>
42    where
43        R: Sync + Send;
44}
45
46/// connection pool to spu
47#[derive(Clone)]
48pub struct SpuSocketPool {
49    config: Arc<ClientConfig>,
50    pub(crate) metadata: MetadataStores,
51    spu_clients: Arc<Mutex<HashMap<SpuId, StreamSocket>>>,
52}
53
54impl Drop for SpuSocketPool {
55    fn drop(&mut self) {
56        trace!("dropping spu pool");
57        self.shutdown();
58    }
59}
60
61#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
62#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
63pub trait SpuPool {
64    /// start synchronize based on pool
65    fn start(config: Arc<ClientConfig>, metadata: MetadataStores) -> Result<Self, SocketError>
66    where
67        Self: std::marker::Sized;
68
69    /// create new spu socket
70    async fn connect_to_leader(&self, leader: SpuId) -> Result<StreamSocket, FluvioError>;
71
72    async fn create_serial_socket_from_leader(
73        &self,
74        leader_id: SpuId,
75    ) -> Result<VersionedSerialSocket, FluvioError>;
76
77    async fn topic_exists(&self, topic: String) -> Result<bool, FluvioError>;
78
79    fn shutdown(&mut self);
80
81    fn topics(&self) -> &StoreContext<TopicSpec>;
82
83    fn partitions(&self) -> &StoreContext<PartitionSpec>;
84}
85
86#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
87#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
88impl SpuPool for SpuSocketPool {
89    /// start synchronize based on pool
90    fn start(config: Arc<ClientConfig>, metadata: MetadataStores) -> Result<Self, SocketError> {
91        debug!("starting spu pool");
92        Ok(Self {
93            metadata,
94            config,
95            spu_clients: Arc::new(Mutex::new(HashMap::new())),
96        })
97    }
98
99    /// create new spu socket
100    #[instrument(skip(self))]
101    async fn connect_to_leader(&self, leader: SpuId) -> Result<StreamSocket, FluvioError> {
102        let spu = self.metadata.spus().look_up_by_id(leader).await?;
103
104        let mut client_config = self.config.with_prefix_sni_domain(spu.key());
105
106        let spu_addr = match spu.spec.public_endpoint_local {
107            Some(local) if self.config.use_spu_local_address() => {
108                let host = local.host;
109                let port = local.port;
110                format!("{host}:{port}")
111            }
112            _ => spu.spec.public_endpoint.addr(),
113        };
114
115        debug!(leader = spu.spec.id,addr = %spu_addr,"try connecting to spu");
116        client_config.set_addr(spu_addr);
117        let versioned_socket = client_config.connect().await?;
118        let (socket, config, versions) = versioned_socket.split();
119        Ok(StreamSocket::new(
120            config,
121            MultiplexerSocket::shared(socket),
122            versions,
123        ))
124    }
125
126    #[instrument(skip(self))]
127    async fn create_serial_socket_from_leader(
128        &self,
129        leader_id: SpuId,
130    ) -> Result<VersionedSerialSocket, FluvioError> {
131        // check if already have existing connection to same SPU
132        let mut client_lock = self.spu_clients.lock().await;
133
134        if let Some(spu_socket) = client_lock.get_mut(&leader_id) {
135            if !spu_socket.is_stale() {
136                return Ok(spu_socket.create_serial_socket().await);
137            } else {
138                client_lock.remove(&leader_id);
139            }
140        }
141
142        let mut spu_socket = self.connect_to_leader(leader_id).await?;
143        let serial_socket = spu_socket.create_serial_socket().await;
144        client_lock.insert(leader_id, spu_socket);
145
146        Ok(serial_socket)
147    }
148
149    async fn topic_exists(&self, topic: String) -> Result<bool, FluvioError> {
150        let replica = ReplicaKey::new(topic, 0u32);
151        Ok(self.partitions().lookup_by_key(&replica).await?.is_some())
152    }
153
154    fn shutdown(&mut self) {
155        self.metadata.shutdown();
156    }
157
158    fn topics(&self) -> &StoreContext<TopicSpec> {
159        self.metadata.topics()
160    }
161
162    fn partitions(&self) -> &StoreContext<PartitionSpec> {
163        self.metadata.partitions()
164    }
165}
166
167#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
168#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
169impl SpuDirectory for SpuSocketPool {
170    /// Create request/response socket to SPU for a replica
171    ///
172    /// All sockets to same SPU use a single TCP connection.
173    /// First this looks up SPU address in SPU metadata and try to see if there is an existing TCP connection.
174    /// If not, it will create a new connection and creates socket to it
175    #[instrument(skip(self, replica))]
176    async fn create_serial_socket(
177        &self,
178        replica: &ReplicaKey,
179    ) -> Result<VersionedSerialSocket, FluvioError> {
180        let partition_search = self.metadata.partitions().lookup_by_key(replica).await?;
181        let partition = if let Some(partition) = partition_search {
182            partition
183        } else {
184            return Err(FluvioError::PartitionNotFound(
185                replica.topic.to_string(),
186                replica.partition,
187            ));
188        };
189
190        let leader_id = partition.spec.leader;
191        let socket = self.create_serial_socket_from_leader(leader_id).await?;
192        Ok(socket)
193    }
194
195    #[instrument(skip(self, replica, request, version))]
196    async fn create_stream_with_version<R: Request>(
197        &self,
198        replica: &ReplicaKey,
199        request: R,
200        version: i16,
201    ) -> Result<AsyncResponse<R>, FluvioError>
202    where
203        R: Sync + Send,
204    {
205        let partition_search = self.metadata.partitions().lookup_by_key(replica).await?;
206
207        let partition = if let Some(partition) = partition_search {
208            partition
209        } else {
210            return Err(FluvioError::PartitionNotFound(
211                replica.topic.to_owned(),
212                replica.partition,
213            ));
214        };
215
216        let leader_id = partition.spec.leader;
217
218        // check if already have existing leader or create new connection to leader
219        let mut client_lock = self.spu_clients.lock().await;
220
221        if let Some(spu_socket) = client_lock.get_mut(&leader_id) {
222            return spu_socket
223                .create_stream_with_version(request, version)
224                .await
225                .map_err(|err| err.into());
226        }
227
228        let mut spu_socket = self.connect_to_leader(leader_id).await?;
229        let stream = spu_socket
230            .create_stream_with_version(request, version)
231            .await?;
232        client_lock.insert(leader_id, spu_socket);
233
234        Ok(stream)
235    }
236}