dynamo_runtime/component/
client.rs1use crate::pipeline::{
17 network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter},
18 AsyncEngine, Data, ManyOut, SingleIn,
19};
20use rand::Rng;
21use std::collections::HashMap;
22use std::sync::{
23 atomic::{AtomicU64, Ordering},
24 Arc,
25};
26use tokio::{net::unix::pipe::Receiver, sync::Mutex};
27
28use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error};
29
30use super::*;
31
/// Three-state marker enum: empty, non-empty, or finished.
///
/// NOTE(review): this type is not referenced anywhere else in this file, and the
/// meaning of its `u64` payloads is not documented here. Confirm whether it is
/// still needed, or remove it.
enum MapState {
    /// Empty state carrying a single `u64` (meaning unconfirmed).
    Empty(u64),

    /// Non-empty state carrying two `u64`s (meaning unconfirmed).
    NonEmpty(u64, u64),

    /// Terminal state with no payload.
    Finished,
}
45
/// Endpoint-level change event: a keyed put carrying an `i64`, or a keyed delete.
///
/// NOTE(review): this type is not constructed or matched anywhere else in this file.
/// The `i64` presumably mirrors the lease id that the etcd watcher stores per key.
/// Confirm before relying on it, or remove it.
enum EndpointEvent {
    /// A key was written, together with its associated `i64` value.
    Put(String, i64),
    /// A key was removed.
    Delete(String),
}
50
/// Strategy a dynamic `Client` uses to choose which endpoint instance gets a request.
///
/// Static clients ignore this setting and always address the endpoint's base subject.
#[derive(Debug, Clone, Copy, Default)]
pub enum RouterMode {
    /// Pick a random id from the currently known endpoint ids. This is the default.
    #[default]
    Random,
    /// Cycle through the known endpoint ids using a counter shared by client clones.
    RoundRobin,
    /// Always target this endpoint id; requests fail if the id is not currently known.
    Direct(i64),
}
62
63#[derive(Clone)]
64pub struct Client<T: Data, U: Data> {
65 endpoint: Endpoint,
66 router: PushRouter<T, U>,
67 counter: Arc<AtomicU64>,
68 endpoints: EndpointSource,
69 router_mode: RouterMode,
70}
71
72#[derive(Clone, Debug)]
73enum EndpointSource {
74 Static,
75 Dynamic(tokio::sync::watch::Receiver<Vec<i64>>),
76}
77
78impl<T, U> Client<T, U>
79where
80 T: Data + Serialize,
81 U: Data + for<'de> Deserialize<'de>,
82{
83 pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
85 Ok(Client {
86 router: router(&endpoint).await?,
87 endpoint,
88 counter: Arc::new(AtomicU64::new(0)),
89 endpoints: EndpointSource::Static,
90 router_mode: Default::default(),
91 })
92 }
93
94 pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
96 let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
98 anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
99 };
100 let prefix_watcher = etcd_client
101 .kv_get_and_watch_prefix(endpoint.etcd_path())
102 .await?;
103
104 let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
105
106 let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
107
108 let secondary = endpoint.component.drt.runtime.secondary().clone();
109
110 secondary.spawn(async move {
114 tracing::debug!("Starting endpoint watcher for prefix: {}", prefix);
115 let mut map = HashMap::new();
116
117 loop {
118 let kv_event = tokio::select! {
119 _ = watch_tx.closed() => {
120 tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix);
121 break;
122 }
123 kv_event = kv_event_rx.recv() => {
124 match kv_event {
125 Some(kv_event) => kv_event,
126 None => {
127 tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix);
128 break;
129 }
130 }
131 }
132 };
133
134 match kv_event {
135 WatchEvent::Put(kv) => {
136 let key = String::from_utf8(kv.key().to_vec());
137 let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
138 if let (Ok(key), Ok(val)) = (key, val) {
139 map.insert(key.clone(), val.lease_id);
140 } else {
141 tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
142 break;
143 }
144 }
145 WatchEvent::Delete(kv) => {
146 match String::from_utf8(kv.key().to_vec()) {
147 Ok(key) => { map.remove(&key); }
148 Err(_) => {
149 tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
150 break;
151 }
152 }
153 }
154 }
155
156 let endpoint_ids: Vec<i64> = map.values().cloned().collect();
157
158 if watch_tx.send(endpoint_ids).is_err() {
159 tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
160 break;
161 }
162
163 }
164
165 tracing::debug!("Completed endpoint watcher for prefix: {}", prefix);
166 let _ = watch_tx.send(vec![]);
167 });
168
169 Ok(Client {
170 router: router(&endpoint).await?,
171 endpoint,
172 counter: Arc::new(AtomicU64::new(0)),
173 endpoints: EndpointSource::Dynamic(watch_rx),
174 router_mode: Default::default(),
175 })
176 }
177
178 pub fn path(&self) -> String {
180 self.endpoint.path()
181 }
182
183 pub fn etcd_path(&self) -> String {
185 self.endpoint.etcd_path()
186 }
187
188 pub fn endpoint_ids(&self) -> Vec<i64> {
189 match &self.endpoints {
190 EndpointSource::Static => vec![0],
191 EndpointSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
192 }
193 }
194
195 pub fn set_router_mode(&mut self, mode: RouterMode) {
196 self.router_mode = mode
197 }
198
199 pub async fn wait_for_endpoints(&self) -> Result<()> {
201 if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() {
202 loop {
204 if rx.borrow_and_update().is_empty() {
205 rx.changed().await?;
206 } else {
207 break;
208 }
209 }
210 }
211 Ok(())
212 }
213
214 pub fn is_static(&self) -> bool {
216 matches!(self.endpoints, EndpointSource::Static)
217 }
218
219 pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
221 let counter = self.counter.fetch_add(1, Ordering::Relaxed);
222
223 let endpoint_id = {
224 let endpoints = self.endpoint_ids();
225 let count = endpoints.len();
226 if count == 0 {
227 return Err(error!(
228 "no endpoints found for endpoint {:?}",
229 self.endpoint.etcd_path()
230 ));
231 }
232 let offset = counter % count as u64;
233 endpoints[offset as usize]
234 };
235 tracing::trace!("round robin router selected {endpoint_id}");
236
237 let subject = self.endpoint.subject_to(endpoint_id);
238 let request = request.map(|req| AddressedRequest::new(req, subject));
239
240 self.router.generate(request).await
241 }
242
243 pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
245 let endpoint_id = {
246 let endpoints = self.endpoint_ids();
247 let count = endpoints.len();
248 if count == 0 {
249 return Err(error!(
250 "no endpoints found for endpoint {:?}",
251 self.endpoint.etcd_path()
252 ));
253 }
254 let counter = rand::rng().random::<u64>();
255 let offset = counter % count as u64;
256 endpoints[offset as usize]
257 };
258 tracing::trace!("random router selected {endpoint_id}");
259
260 let subject = self.endpoint.subject_to(endpoint_id);
261 let request = request.map(|req| AddressedRequest::new(req, subject));
262
263 self.router.generate(request).await
264 }
265
266 pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
268 let found = {
269 let endpoints = self.endpoint_ids();
270 endpoints.contains(&endpoint_id)
271 };
272
273 if !found {
274 return Err(error!(
275 "endpoint_id={} not found for endpoint {:?}",
276 endpoint_id,
277 self.endpoint.etcd_path()
278 ));
279 }
280
281 let subject = self.endpoint.subject_to(endpoint_id);
282 let request = request.map(|req| AddressedRequest::new(req, subject));
283
284 self.router.generate(request).await
285 }
286
287 pub async fn r#static(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
288 let subject = self.endpoint.subject();
289 tracing::debug!("static got subject: {subject}");
290 let request = request.map(|req| AddressedRequest::new(req, subject));
291 tracing::debug!("router generate");
292 self.router.generate(request).await
293 }
294}
295
296async fn router(endpoint: &Endpoint) -> Result<Arc<AddressedPushRouter>> {
297 AddressedPushRouter::new(
298 endpoint.component.drt.nats_client.client().clone(),
299 endpoint.component.drt.tcp_server().await?,
300 )
301}
302
303#[async_trait]
304impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
305where
306 T: Data + Serialize,
307 U: Data + for<'de> Deserialize<'de>,
308{
309 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
310 match &self.endpoints {
311 EndpointSource::Static => self.r#static(request).await,
312 EndpointSource::Dynamic(_) => match self.router_mode {
313 RouterMode::Random => self.random(request).await,
314 RouterMode::RoundRobin => self.round_robin(request).await,
315 RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
316 },
317 }
318 }
319}