triton_distributed/component/
client.rs1use crate::pipeline::{
17 network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter},
18 AsyncEngine, Data, ManyOut, SingleIn,
19};
20use rand::Rng;
21use std::collections::HashMap;
22use std::sync::{
23 atomic::{AtomicU64, Ordering},
24 Arc,
25};
26use tokio::{net::unix::pipe::Receiver, sync::Mutex};
27
28use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error};
29
30use super::*;
31
/// Three-state progress marker: empty, non-empty, or finished.
///
/// NOTE(review): this enum is not referenced by any code visible in this file; confirm it is
/// still needed. The meaning of the `u64` payloads is not established here — TODO confirm.
enum MapState {
    /// Nothing tracked yet; carries one `u64` value.
    Empty(u64),

    /// Something tracked; carries two `u64` values.
    NonEmpty(u64, u64),

    /// Terminal state; carries no data.
    Finished,
}
45
/// A change to the set of registered endpoints.
///
/// NOTE(review): not referenced by the code visible in this file. The `String` is presumably
/// the etcd key and the `i64` the lease id, mirroring what the watcher in `Client::new`
/// tracks — TODO confirm before relying on it.
enum EndpointEvent {
    /// A key was written; carries the key and an `i64` id.
    Put(String, i64),
    /// A key was removed; carries the key.
    Delete(String),
}
50
51#[derive(Clone)]
52pub struct Client<T: Data, U: Data> {
53 endpoint: Endpoint,
54 router: PushRouter<T, U>,
55 watch_rx: tokio::sync::watch::Receiver<Vec<i64>>,
56 counter: Arc<AtomicU64>,
57}
58
59impl<T, U> Client<T, U>
60where
61 T: Data + Serialize,
62 U: Data + for<'de> Deserialize<'de>,
63{
64 pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
65 let router = AddressedPushRouter::new(
66 endpoint.component.drt.nats_client.client().clone(),
67 endpoint.component.drt.tcp_server().await?,
68 )?;
69
70 let prefix_watcher = endpoint
72 .component
73 .drt
74 .etcd_client
75 .kv_get_and_watch_prefix(endpoint.etcd_path())
76 .await?;
77
78 let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
79
80 let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
81
82 let secondary = endpoint.component.drt.runtime.secondary().clone();
83
84 secondary.spawn(async move {
88 log::debug!("Starting endpoint watcher for prefix: {}", prefix);
89 let mut map = HashMap::new();
90
91 loop {
92 let kv_event = tokio::select! {
93 _ = watch_tx.closed() => {
94 log::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix);
95 break;
96 }
97 kv_event = kv_event_rx.recv() => {
98 match kv_event {
99 Some(kv_event) => kv_event,
100 None => {
101 log::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix);
102 break;
103 }
104 }
105 }
106 };
107
108 match kv_event {
109 WatchEvent::Put(kv) => {
110 let key = String::from_utf8(kv.key().to_vec());
111 let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
112 if let (Ok(key), Ok(val)) = (key, val) {
113 map.insert(key.clone(), val.lease_id);
114 } else {
115 log::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
116 break;
117 }
118 }
119 WatchEvent::Delete(kv) => {
120 match String::from_utf8(kv.key().to_vec()) {
121 Ok(key) => { map.remove(&key); }
122 Err(_) => {
123 log::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
124 break;
125 }
126 }
127 }
128 }
129
130 let endpoint_ids: Vec<i64> = map.values().cloned().collect();
131
132 if watch_tx.send(endpoint_ids).is_err() {
133 log::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
134 break;
135 }
136
137 }
138
139 log::debug!("Completed endpoint watcher for prefix: {}", prefix);
140 let _ = watch_tx.send(vec![]);
141 });
142
143 Ok(Client {
144 endpoint,
145 router,
146 watch_rx,
147 counter: Arc::new(AtomicU64::new(0)),
148 })
149 }
150
151 pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> {
152 &self.watch_rx
153 }
154
155 pub async fn wait_for_endpoints(&self) -> Result<()> {
157 let mut rx = self.watch_rx.clone();
158 loop {
160 if rx.borrow_and_update().is_empty() {
161 rx.changed().await?;
162 } else {
163 break;
164 }
165 }
166
167 Ok(())
168 }
169
170 pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
172 let counter = self.counter.fetch_add(1, Ordering::Relaxed);
173
174 let endpoint_id = {
175 let endpoints = self.watch_rx.borrow();
176 let count = endpoints.len();
177 if count == 0 {
178 return Err(error!(
179 "no endpoints found for endpoint {:?}",
180 self.endpoint.etcd_path()
181 ));
182 }
183 let offset = counter % count as u64;
184 endpoints[offset as usize]
185 };
186
187 let subject = self.endpoint.subject(endpoint_id);
188 let request = request.map(|req| AddressedRequest::new(req, subject));
189
190 self.router.generate(request).await
191 }
192
193 pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
195 let endpoint_id = {
196 let endpoints = self.watch_rx.borrow();
197 let count = endpoints.len();
198 if count == 0 {
199 return Err(error!(
200 "no endpoints found for endpoint {:?}",
201 self.endpoint.etcd_path()
202 ));
203 }
204 let counter = rand::thread_rng().gen::<u64>();
205 let offset = counter % count as u64;
206 endpoints[offset as usize]
207 };
208
209 let subject = self.endpoint.subject(endpoint_id);
210 let request = request.map(|req| AddressedRequest::new(req, subject));
211
212 self.router.generate(request).await
213 }
214
215 pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
217 let found = {
218 let endpoints = self.watch_rx.borrow();
219 endpoints.contains(&endpoint_id)
220 };
221
222 if !found {
223 return Err(error!(
224 "endpoint_id={} not found for endpoint {:?}",
225 endpoint_id,
226 self.endpoint.etcd_path()
227 ));
228 }
229
230 let subject = self.endpoint.subject(endpoint_id);
231 let request = request.map(|req| AddressedRequest::new(req, subject));
232
233 self.router.generate(request).await
234 }
235}
236
237#[async_trait]
238impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
239where
240 T: Data + Serialize,
241 U: Data + for<'de> Deserialize<'de>,
242{
243 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
244 self.random(request).await
245 }
246}