dynamo_runtime/component/
client.rs1use crate::pipeline::{
17 network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter},
18 AsyncEngine, Data, ManyOut, SingleIn,
19};
20use rand::Rng;
21use std::collections::HashMap;
22use std::sync::{
23 atomic::{AtomicU64, Ordering},
24 Arc,
25};
26use tokio::{net::unix::pipe::Receiver, sync::Mutex};
27
28use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error};
29
30use super::*;
31
/// State of the endpoint map.
///
/// NOTE(review): this enum is not referenced anywhere in the visible part of
/// this file; confirm whether it is still needed before relying on it.
enum MapState {
    /// The map holds no entries.
    /// NOTE(review): the meaning of the `u64` payload is not evident from this
    /// file (possibly a nonce/sequence number) — confirm.
    Empty(u64),

    /// The map holds at least one entry.
    /// NOTE(review): the meaning of the two `u64` payloads is not evident from
    /// this file (possibly a nonce and an entry count) — confirm.
    NonEmpty(u64, u64),

    /// Terminal state.
    /// NOTE(review): presumably emitted once the watcher has stopped — confirm.
    Finished,
}
45
/// A change to the set of registered endpoints.
///
/// NOTE(review): this enum is not referenced anywhere in the visible part of
/// this file; the watcher in `Client::new` consumes `WatchEvent` directly.
enum EndpointEvent {
    /// An endpoint entry was added or updated.
    /// NOTE(review): presumably `(key, lease_id)`, mirroring what the watcher
    /// stores in its map — confirm.
    Put(String, i64),
    /// An endpoint entry was removed.
    /// NOTE(review): the `String` is presumably the entry's key — confirm.
    Delete(String),
}
50
/// Client for sending requests of type `T` to the instances of an [`Endpoint`]
/// and receiving streams of `U` in response.
///
/// Cloning is derived; because `counter` is held in an `Arc`, all clones share
/// the same round-robin counter.
#[derive(Clone)]
pub struct Client<T: Data, U: Data> {
    /// The endpoint this client sends requests to.
    endpoint: Endpoint,
    /// Router through which addressed requests are dispatched.
    router: PushRouter<T, U>,
    /// Receives the current list of endpoint ids published by the background
    /// watcher task spawned in `Client::new`.
    watch_rx: tokio::sync::watch::Receiver<Vec<i64>>,
    /// Monotonically increasing counter used by `round_robin` to pick the next
    /// endpoint.
    counter: Arc<AtomicU64>,
}
58
59impl<T, U> Client<T, U>
60where
61 T: Data + Serialize,
62 U: Data + for<'de> Deserialize<'de>,
63{
64 pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
65 let router = AddressedPushRouter::new(
66 endpoint.component.drt.nats_client.client().clone(),
67 endpoint.component.drt.tcp_server().await?,
68 )?;
69
70 let prefix_watcher = endpoint
72 .component
73 .drt
74 .etcd_client
75 .kv_get_and_watch_prefix(endpoint.etcd_path())
76 .await?;
77
78 let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
79
80 let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
81
82 let secondary = endpoint.component.drt.runtime.secondary().clone();
83
84 secondary.spawn(async move {
88 tracing::debug!("Starting endpoint watcher for prefix: {}", prefix);
89 let mut map = HashMap::new();
90
91 loop {
92 let kv_event = tokio::select! {
93 _ = watch_tx.closed() => {
94 tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix);
95 break;
96 }
97 kv_event = kv_event_rx.recv() => {
98 match kv_event {
99 Some(kv_event) => kv_event,
100 None => {
101 tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix);
102 break;
103 }
104 }
105 }
106 };
107
108 match kv_event {
109 WatchEvent::Put(kv) => {
110 let key = String::from_utf8(kv.key().to_vec());
111 let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
112 if let (Ok(key), Ok(val)) = (key, val) {
113 map.insert(key.clone(), val.lease_id);
114 } else {
115 tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
116 break;
117 }
118 }
119 WatchEvent::Delete(kv) => {
120 match String::from_utf8(kv.key().to_vec()) {
121 Ok(key) => { map.remove(&key); }
122 Err(_) => {
123 tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
124 break;
125 }
126 }
127 }
128 }
129
130 let endpoint_ids: Vec<i64> = map.values().cloned().collect();
131
132 if watch_tx.send(endpoint_ids).is_err() {
133 tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
134 break;
135 }
136
137 }
138
139 tracing::debug!("Completed endpoint watcher for prefix: {}", prefix);
140 let _ = watch_tx.send(vec![]);
141 });
142
143 Ok(Client {
144 endpoint,
145 router,
146 watch_rx,
147 counter: Arc::new(AtomicU64::new(0)),
148 })
149 }
150
151 pub fn path(&self) -> String {
153 self.endpoint.path()
154 }
155
156 pub fn etcd_path(&self) -> String {
158 self.endpoint.etcd_path()
159 }
160
161 pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> {
162 &self.watch_rx
163 }
164
165 pub async fn wait_for_endpoints(&self) -> Result<()> {
167 let mut rx = self.watch_rx.clone();
168 loop {
170 if rx.borrow_and_update().is_empty() {
171 rx.changed().await?;
172 } else {
173 break;
174 }
175 }
176
177 Ok(())
178 }
179
180 pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
182 let counter = self.counter.fetch_add(1, Ordering::Relaxed);
183
184 let endpoint_id = {
185 let endpoints = self.watch_rx.borrow();
186 let count = endpoints.len();
187 if count == 0 {
188 return Err(error!(
189 "no endpoints found for endpoint {:?}",
190 self.endpoint.etcd_path()
191 ));
192 }
193 let offset = counter % count as u64;
194 endpoints[offset as usize]
195 };
196
197 let subject = self.endpoint.subject_to(endpoint_id);
198 let request = request.map(|req| AddressedRequest::new(req, subject));
199
200 self.router.generate(request).await
201 }
202
203 pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
205 let endpoint_id = {
206 let endpoints = self.watch_rx.borrow();
207 let count = endpoints.len();
208 if count == 0 {
209 return Err(error!(
210 "no endpoints found for endpoint {:?}",
211 self.endpoint.etcd_path()
212 ));
213 }
214 let counter = rand::thread_rng().gen::<u64>();
215 let offset = counter % count as u64;
216 endpoints[offset as usize]
217 };
218
219 let subject = self.endpoint.subject_to(endpoint_id);
220 let request = request.map(|req| AddressedRequest::new(req, subject));
221
222 self.router.generate(request).await
223 }
224
225 pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
227 let found = {
228 let endpoints = self.watch_rx.borrow();
229 endpoints.contains(&endpoint_id)
230 };
231
232 if !found {
233 return Err(error!(
234 "endpoint_id={} not found for endpoint {:?}",
235 endpoint_id,
236 self.endpoint.etcd_path()
237 ));
238 }
239
240 let subject = self.endpoint.subject_to(endpoint_id);
241 let request = request.map(|req| AddressedRequest::new(req, subject));
242
243 self.router.generate(request).await
244 }
245}
246
/// Lets a [`Client`] be used wherever an `AsyncEngine` from `SingleIn<T>` to
/// `ManyOut<U>` is expected.
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
where
    T: Data + Serialize,
    U: Data + for<'de> Deserialize<'de>,
{
    /// Dispatches `request` using the random routing strategy; this is a
    /// direct delegation to `Client::random`, including its error behavior.
    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
        self.random(request).await
    }
}