1use std::sync::Arc;
2use std::collections::HashMap;
3use anyhow::Result;
4
5use fluvio_sc_schema::partition::PartitionSpec;
6use fluvio_sc_schema::topic::TopicSpec;
7use tracing::{debug, trace, instrument};
8use async_lock::Mutex;
9use async_trait::async_trait;
10
11use fluvio_protocol::record::ReplicaKey;
12use fluvio_protocol::api::Request;
13use fluvio_types::SpuId;
14use fluvio_socket::{
15 AsyncResponse, ClientConfig, MultiplexerSocket, SocketError, StreamSocket,
16 VersionedSerialSocket,
17};
18use crate::FluvioError;
19use crate::sync::{MetadataStores, StoreContext};
20
21#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
23#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
24pub trait SpuDirectory {
25 async fn create_serial_socket(
31 &self,
32 replica: &ReplicaKey,
33 ) -> Result<VersionedSerialSocket, FluvioError>;
34
35 async fn create_stream_with_version<R: Request>(
37 &self,
38 replica: &ReplicaKey,
39 request: R,
40 version: i16,
41 ) -> Result<AsyncResponse<R>, FluvioError>
42 where
43 R: Sync + Send;
44}
45
46#[derive(Clone)]
48pub struct SpuSocketPool {
49 config: Arc<ClientConfig>,
50 pub(crate) metadata: MetadataStores,
51 spu_clients: Arc<Mutex<HashMap<SpuId, StreamSocket>>>,
52}
53
54impl Drop for SpuSocketPool {
55 fn drop(&mut self) {
56 trace!("dropping spu pool");
57 self.shutdown();
58 }
59}
60
61#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
62#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
63pub trait SpuPool {
64 fn start(config: Arc<ClientConfig>, metadata: MetadataStores) -> Result<Self, SocketError>
66 where
67 Self: std::marker::Sized;
68
69 async fn connect_to_leader(&self, leader: SpuId) -> Result<StreamSocket, FluvioError>;
71
72 async fn create_serial_socket_from_leader(
73 &self,
74 leader_id: SpuId,
75 ) -> Result<VersionedSerialSocket, FluvioError>;
76
77 async fn topic_exists(&self, topic: String) -> Result<bool, FluvioError>;
78
79 fn shutdown(&mut self);
80
81 fn topics(&self) -> &StoreContext<TopicSpec>;
82
83 fn partitions(&self) -> &StoreContext<PartitionSpec>;
84}
85
86#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
87#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
88impl SpuPool for SpuSocketPool {
89 fn start(config: Arc<ClientConfig>, metadata: MetadataStores) -> Result<Self, SocketError> {
91 debug!("starting spu pool");
92 Ok(Self {
93 metadata,
94 config,
95 spu_clients: Arc::new(Mutex::new(HashMap::new())),
96 })
97 }
98
99 #[instrument(skip(self))]
101 async fn connect_to_leader(&self, leader: SpuId) -> Result<StreamSocket, FluvioError> {
102 let spu = self.metadata.spus().look_up_by_id(leader).await?;
103
104 let mut client_config = self.config.with_prefix_sni_domain(spu.key());
105
106 let spu_addr = match spu.spec.public_endpoint_local {
107 Some(local) if self.config.use_spu_local_address() => {
108 let host = local.host;
109 let port = local.port;
110 format!("{host}:{port}")
111 }
112 _ => spu.spec.public_endpoint.addr(),
113 };
114
115 debug!(leader = spu.spec.id,addr = %spu_addr,"try connecting to spu");
116 client_config.set_addr(spu_addr);
117 let versioned_socket = client_config.connect().await?;
118 let (socket, config, versions) = versioned_socket.split();
119 Ok(StreamSocket::new(
120 config,
121 MultiplexerSocket::shared(socket),
122 versions,
123 ))
124 }
125
126 #[instrument(skip(self))]
127 async fn create_serial_socket_from_leader(
128 &self,
129 leader_id: SpuId,
130 ) -> Result<VersionedSerialSocket, FluvioError> {
131 let mut client_lock = self.spu_clients.lock().await;
133
134 if let Some(spu_socket) = client_lock.get_mut(&leader_id) {
135 if !spu_socket.is_stale() {
136 return Ok(spu_socket.create_serial_socket().await);
137 } else {
138 client_lock.remove(&leader_id);
139 }
140 }
141
142 let mut spu_socket = self.connect_to_leader(leader_id).await?;
143 let serial_socket = spu_socket.create_serial_socket().await;
144 client_lock.insert(leader_id, spu_socket);
145
146 Ok(serial_socket)
147 }
148
149 async fn topic_exists(&self, topic: String) -> Result<bool, FluvioError> {
150 let replica = ReplicaKey::new(topic, 0u32);
151 Ok(self.partitions().lookup_by_key(&replica).await?.is_some())
152 }
153
154 fn shutdown(&mut self) {
155 self.metadata.shutdown();
156 }
157
158 fn topics(&self) -> &StoreContext<TopicSpec> {
159 self.metadata.topics()
160 }
161
162 fn partitions(&self) -> &StoreContext<PartitionSpec> {
163 self.metadata.partitions()
164 }
165}
166
167#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
168#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
169impl SpuDirectory for SpuSocketPool {
170 #[instrument(skip(self, replica))]
176 async fn create_serial_socket(
177 &self,
178 replica: &ReplicaKey,
179 ) -> Result<VersionedSerialSocket, FluvioError> {
180 let partition_search = self.metadata.partitions().lookup_by_key(replica).await?;
181 let partition = if let Some(partition) = partition_search {
182 partition
183 } else {
184 return Err(FluvioError::PartitionNotFound(
185 replica.topic.to_string(),
186 replica.partition,
187 ));
188 };
189
190 let leader_id = partition.spec.leader;
191 let socket = self.create_serial_socket_from_leader(leader_id).await?;
192 Ok(socket)
193 }
194
195 #[instrument(skip(self, replica, request, version))]
196 async fn create_stream_with_version<R: Request>(
197 &self,
198 replica: &ReplicaKey,
199 request: R,
200 version: i16,
201 ) -> Result<AsyncResponse<R>, FluvioError>
202 where
203 R: Sync + Send,
204 {
205 let partition_search = self.metadata.partitions().lookup_by_key(replica).await?;
206
207 let partition = if let Some(partition) = partition_search {
208 partition
209 } else {
210 return Err(FluvioError::PartitionNotFound(
211 replica.topic.to_owned(),
212 replica.partition,
213 ));
214 };
215
216 let leader_id = partition.spec.leader;
217
218 let mut client_lock = self.spu_clients.lock().await;
220
221 if let Some(spu_socket) = client_lock.get_mut(&leader_id) {
222 return spu_socket
223 .create_stream_with_version(request, version)
224 .await
225 .map_err(|err| err.into());
226 }
227
228 let mut spu_socket = self.connect_to_leader(leader_id).await?;
229 let stream = spu_socket
230 .create_stream_with_version(request, version)
231 .await?;
232 client_lock.insert(leader_id, spu_socket);
233
234 Ok(stream)
235 }
236}