triton_distributed/component/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::pipeline::{
17    network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter},
18    AsyncEngine, Data, ManyOut, SingleIn,
19};
20use rand::Rng;
21use std::collections::HashMap;
22use std::sync::{
23    atomic::{AtomicU64, Ordering},
24    Arc,
25};
26use tokio::{net::unix::pipe::Receiver, sync::Mutex};
27
28use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error};
29
30use super::*;
31
/// Occupancy state of the endpoint map, tagged with a nonce.
///
/// Each state has a nonce associated with it. The state is meant to be
/// emitted on a watch channel, so that observers can follow the critical
/// state transitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MapState {
    /// The map is empty; value = nonce
    Empty(u64),

    /// The map is non-empty; values are (nonce, count)
    NonEmpty(u64, u64),

    /// The watcher has finished; no more events will be emitted
    Finished,
}
45
/// Change notification for a single endpoint entry.
///
/// NOTE(review): this type is not referenced in the visible part of the file.
/// The payloads presumably mirror the watcher's map (key -> lease id) —
/// confirm before putting it to use.
#[derive(Debug, Clone, PartialEq, Eq)]
enum EndpointEvent {
    /// An entry was added or updated.
    Put(String, i64),
    /// An entry was removed.
    Delete(String),
}
50
/// Client for issuing requests to the live instances of an [`Endpoint`].
///
/// Requests of type `T` are sent through `router`; responses are streams of
/// `U`. The target instance for each request is chosen from the id list
/// currently held by `watch_rx`.
#[derive(Clone)]
pub struct Client<T: Data, U: Data> {
    /// Endpoint this client targets; used to build per-instance subjects and
    /// to label error messages.
    endpoint: Endpoint,
    /// Router that delivers each addressed request.
    router: PushRouter<T, U>,
    /// Receiver side of a watch channel carrying the latest list of endpoint ids.
    watch_rx: tokio::sync::watch::Receiver<Vec<i64>>,
    /// Counter driving round-robin selection; behind an `Arc`, so clones of
    /// this client share the same counter.
    counter: Arc<AtomicU64>,
}
58
59impl<T, U> Client<T, U>
60where
61    T: Data + Serialize,
62    U: Data + for<'de> Deserialize<'de>,
63{
64    pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
65        let router = AddressedPushRouter::new(
66            endpoint.component.drt.nats_client.client().clone(),
67            endpoint.component.drt.tcp_server().await?,
68        )?;
69
70        // create live endpoint watcher
71        let prefix_watcher = endpoint
72            .component
73            .drt
74            .etcd_client
75            .kv_get_and_watch_prefix(endpoint.etcd_path())
76            .await?;
77
78        let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
79
80        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
81
82        let secondary = endpoint.component.drt.runtime.secondary().clone();
83
84        // this task should be included in the registry
85        // currently this is created once per client, but this object/task should only be instantiated
86        // once per worker/instance
87        secondary.spawn(async move {
88            log::debug!("Starting endpoint watcher for prefix: {}", prefix);
89            let mut map = HashMap::new();
90
91            loop {
92                let kv_event = tokio::select! {
93                    _ = watch_tx.closed() => {
94                        log::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix);
95                        break;
96                    }
97                    kv_event = kv_event_rx.recv() => {
98                        match kv_event {
99                            Some(kv_event) => kv_event,
100                            None => {
101                                log::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix);
102                                break;
103                            }
104                        }
105                    }
106                };
107
108                match kv_event {
109                    WatchEvent::Put(kv) => {
110                        let key = String::from_utf8(kv.key().to_vec());
111                        let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
112                        if let (Ok(key), Ok(val)) = (key, val) {
113                            map.insert(key.clone(), val.lease_id);
114                        } else {
115                            log::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
116                            break;
117                        }
118                    }
119                    WatchEvent::Delete(kv) => {
120                        match String::from_utf8(kv.key().to_vec()) {
121                            Ok(key) => { map.remove(&key); }
122                            Err(_) => {
123                                log::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
124                                break;
125                            }
126                        }
127                    }
128                }
129
130                let endpoint_ids: Vec<i64> = map.values().cloned().collect();
131
132                if watch_tx.send(endpoint_ids).is_err() {
133                    log::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
134                    break;
135                }
136
137            }
138
139            log::debug!("Completed endpoint watcher for prefix: {}", prefix);
140            let _ = watch_tx.send(vec![]);
141        });
142
143        Ok(Client {
144            endpoint,
145            router,
146            watch_rx,
147            counter: Arc::new(AtomicU64::new(0)),
148        })
149    }
150
151    pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> {
152        &self.watch_rx
153    }
154
155    /// Wait for at least one [`Endpoint`] to be available
156    pub async fn wait_for_endpoints(&self) -> Result<()> {
157        let mut rx = self.watch_rx.clone();
158        // wait for there to be 1 or more endpoints
159        loop {
160            if rx.borrow_and_update().is_empty() {
161                rx.changed().await?;
162            } else {
163                break;
164            }
165        }
166
167        Ok(())
168    }
169
170    /// Issue a request to the next available endpoint in a round-robin fashion
171    pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
172        let counter = self.counter.fetch_add(1, Ordering::Relaxed);
173
174        let endpoint_id = {
175            let endpoints = self.watch_rx.borrow();
176            let count = endpoints.len();
177            if count == 0 {
178                return Err(error!(
179                    "no endpoints found for endpoint {:?}",
180                    self.endpoint.etcd_path()
181                ));
182            }
183            let offset = counter % count as u64;
184            endpoints[offset as usize]
185        };
186
187        let subject = self.endpoint.subject(endpoint_id);
188        let request = request.map(|req| AddressedRequest::new(req, subject));
189
190        self.router.generate(request).await
191    }
192
193    /// Issue a request to a random endpoint
194    pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
195        let endpoint_id = {
196            let endpoints = self.watch_rx.borrow();
197            let count = endpoints.len();
198            if count == 0 {
199                return Err(error!(
200                    "no endpoints found for endpoint {:?}",
201                    self.endpoint.etcd_path()
202                ));
203            }
204            let counter = rand::thread_rng().gen::<u64>();
205            let offset = counter % count as u64;
206            endpoints[offset as usize]
207        };
208
209        let subject = self.endpoint.subject(endpoint_id);
210        let request = request.map(|req| AddressedRequest::new(req, subject));
211
212        self.router.generate(request).await
213    }
214
215    /// Issue a request to a specific endpoint
216    pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
217        let found = {
218            let endpoints = self.watch_rx.borrow();
219            endpoints.contains(&endpoint_id)
220        };
221
222        if !found {
223            return Err(error!(
224                "endpoint_id={} not found for endpoint {:?}",
225                endpoint_id,
226                self.endpoint.etcd_path()
227            ));
228        }
229
230        let subject = self.endpoint.subject(endpoint_id);
231        let request = request.map(|req| AddressedRequest::new(req, subject));
232
233        self.router.generate(request).await
234    }
235}
236
237#[async_trait]
238impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
239where
240    T: Data + Serialize,
241    U: Data + for<'de> Deserialize<'de>,
242{
243    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
244        self.random(request).await
245    }
246}