decthings_api/client/
mod.rs

1mod error;
2
3#[cfg(feature = "events")]
4pub mod event;
5
6#[cfg(target_os = "espidf")]
7mod espidf_http_impl;
8#[cfg(target_os = "espidf")]
9use espidf_http_impl::*;
10
11mod parameter;
12mod protocol;
13
14#[cfg(not(target_os = "espidf"))]
15mod reqwest_http_impl;
16#[cfg(not(target_os = "espidf"))]
17use reqwest_http_impl::*;
18
19pub mod rpc;
20
21#[cfg(feature = "events")]
22mod websocket;
23
24use std::sync::Arc;
25use tokio::sync::RwLock;
26
27pub use ndarray;
28
29pub use error::{DecthingsClientError, DecthingsRpcError};
30pub use parameter::*;
31
/// Changes that the result handler of an RPC call asks to have applied afterwards.
///
/// Without the `events` feature this struct has no fields at all.
struct StateModification {
    /// Events to add.
    #[cfg(feature = "events")]
    add_events: Vec<String>,
    /// Events to remove.
    #[cfg(feature = "events")]
    remove_events: Vec<String>,
}

impl StateModification {
    /// Builds a modification that changes nothing.
    fn empty() -> Self {
        Self {
            #[cfg(feature = "events")]
            add_events: Vec::new(),
            #[cfg(feature = "events")]
            remove_events: Vec::new(),
        }
    }
}
49
50#[derive(Debug, Clone)]
51pub struct DecthingsClientOptions {
52    #[cfg(feature = "events")]
53    /// Server address to use for WebSocket API. Defaults to "wss://api.decthings.com/v0/ws`
54    pub ws_server_address: String,
55
56    /// Server address to use for HTTP API. Defaults to `https://api.decthings.com/v0`
57    pub http_server_address: String,
58    /// Optional API key. Some methods require this to be set.
59    pub api_key: Option<String>,
60    /// Additional headers to add to each request.
61    pub extra_headers: http::HeaderMap<http::HeaderValue>,
62}
63
64impl std::default::Default for DecthingsClientOptions {
65    fn default() -> Self {
66        Self {
67            #[cfg(feature = "events")]
68            ws_server_address: "wss://api.decthings.com/v0/v0".to_string(),
69
70            http_server_address: "https://api.decthings.com/v0".to_string(),
71            api_key: None,
72            extra_headers: http::HeaderMap::new(),
73        }
74    }
75}
76
/// Selects the transport over which an RPC request is sent.
#[derive(Debug, Clone)]
pub enum RpcProtocol {
    /// Always send the request over HTTP.
    Http,
    /// Always send the request over WebSocket. If no WebSocket is connected, a new one will be
    /// created.
    #[cfg(feature = "events")]
    Ws,
    /// Send over WebSocket when one is already connected; otherwise the request is not sent
    /// at all.
    #[cfg(feature = "events")]
    WsIfAvailableOtherwiseNone,
}
89
90#[derive(Clone)]
91pub(crate) struct DecthingsClientRpc {
92    #[cfg(feature = "events")]
93    ws_server_address: String,
94
95    http_server_address: String,
96    api_key: Arc<RwLock<Option<Arc<str>>>>,
97    extra_headers: Arc<http::HeaderMap<http::HeaderValue>>,
98
99    #[cfg(feature = "events")]
100    event_listeners: Arc<event::EventListeners>,
101
102    #[cfg(feature = "events")]
103    ws: Arc<RwLock<(u64, Option<(u64, Arc<websocket::DecthingsClientWebsocket>)>)>>,
104
105    http: HttpImpl,
106}
107
108impl DecthingsClientRpc {
109    fn new(options: DecthingsClientOptions) -> Self {
110        Self {
111            #[cfg(feature = "events")]
112            ws_server_address: options.ws_server_address,
113
114            http_server_address: options.http_server_address,
115            api_key: Arc::new(RwLock::new(options.api_key.map(Arc::from))),
116            extra_headers: Arc::new(options.extra_headers),
117
118            #[cfg(feature = "events")]
119            event_listeners: Arc::new(event::EventListeners::new()),
120
121            #[cfg(feature = "events")]
122            ws: Arc::new(RwLock::new((0, None))),
123
124            http: HttpImpl::default(),
125        }
126    }
127
128    async fn set_api_key(&self, api_key: String) {
129        let mut locked = self.api_key.write().await;
130        *locked = Some(Arc::from(api_key));
131    }
132
133    #[cfg(feature = "events")]
134    async fn on_event(
135        &self,
136        cb: impl Fn(&event::DecthingsEvent) + Send + Sync + 'static,
137    ) -> event::EventListenerDisposer {
138        self.event_listeners.add(cb).await
139    }
140
141    #[cfg(feature = "events")]
142    async fn maybe_get_socket(&self) -> Option<Arc<websocket::DecthingsClientWebsocket>> {
143        let ws = self.ws.read().await;
144        ws.1.as_ref().map(|inner_ws| Arc::clone(&inner_ws.1))
145    }
146
147    #[cfg(feature = "events")]
148    async fn get_or_create_socket(&self) -> Arc<websocket::DecthingsClientWebsocket> {
149        let ws = self.ws.read().await;
150        if let Some(inner_ws) = ws.1.as_ref() {
151            return Arc::clone(&inner_ws.1);
152        }
153        drop(ws);
154        let mut ws_mut = self.ws.write().await;
155        if let Some(inner_ws) = ws_mut.1.as_ref() {
156            return Arc::clone(&inner_ws.1);
157        }
158        let ws_clone = Arc::clone(&self.ws);
159        let ws_clone2 = Arc::clone(&self.ws);
160        let id = ws_mut.0;
161        ws_mut.0 += 1;
162
163        let event_listeners_clone = Arc::clone(&self.event_listeners);
164        let sock = Arc::new(websocket::DecthingsClientWebsocket::connect(
165            &self.extra_headers,
166            move || async move {
167                let mut ws_clone_lock = ws_clone.write().await;
168                if let Some(inner_ws_clone) = ws_clone_lock.1.as_mut() {
169                    if inner_ws_clone.0 == id {
170                        ws_clone_lock.1 = None;
171                    }
172                }
173                event_listeners_clone
174                    .call(&event::DecthingsEvent::SubscriptionsRemoved)
175                    .await;
176            },
177            move || {
178                let ws_clone3 = Arc::clone(&ws_clone2);
179                async move {
180                    let mut ws_clone_lock = ws_clone3.write().await;
181                    if let Some(inner_ws_clone) = ws_clone_lock.1.as_mut() {
182                        if inner_ws_clone.1.is_unused().await {
183                            ws_clone_lock.1 = None;
184                        }
185                    }
186                }
187            },
188            &self.ws_server_address,
189            Arc::downgrade(&self.event_listeners),
190        ));
191        let _ = ws_mut.1.insert((id, Arc::clone(&sock)));
192        sock
193    }
194
195    /// Call an RPC method on the server.
196    ///
197    /// Returns false if the request is not sent (and on_result is not called). This happens if
198    /// mode is RpcProtocol::WsIfAvailableOtherwiseNone and no WebSocket is connected.
199    async fn raw_method_call<
200        P: serde::Serialize,
201        F: FnOnce(
202                Result<(bytes::Bytes, Vec<bytes::Bytes>), DecthingsClientError>,
203            ) -> StateModification
204            + Send
205            + 'static,
206        D: AsRef<[u8]>,
207    >(
208        &self,
209        api: &str,
210        method: &str,
211        params: P,
212        data: impl AsRef<[D]>,
213        #[allow(unused)] mode: RpcProtocol,
214        on_result: F,
215    ) -> bool {
216        #[cfg(feature = "events")]
217        {
218            let maybe_ws = match mode {
219                RpcProtocol::Http => None,
220                RpcProtocol::Ws => Some(self.get_or_create_socket().await),
221                RpcProtocol::WsIfAvailableOtherwiseNone => {
222                    if let Some(ws) = self.maybe_get_socket().await {
223                        Some(ws)
224                    } else {
225                        return false;
226                    }
227                }
228            };
229
230            if let Some(ws) = maybe_ws {
231                // Send over WebSocket
232                let api_key = {
233                    let locked = self.api_key.read().await;
234                    locked.clone()
235                };
236                ws.call(
237                    api,
238                    method,
239                    params,
240                    api_key.as_deref(),
241                    data,
242                    Box::new(move |x| on_result(x.map_err(|e| e.into()))),
243                )
244                .await;
245                return true;
246            }
247        }
248
249        // Send over HTTP
250        let res = async {
251            let body = protocol::serialize_for_http(params, data.as_ref());
252            drop(data);
253
254            let api_key_locked = self.api_key.read().await;
255            let api_key = api_key_locked.clone();
256            drop(api_key_locked);
257
258            let response_body = self
259                .http
260                .get(
261                    &self.http_server_address,
262                    api,
263                    method,
264                    body,
265                    api_key,
266                    self.extra_headers.clone(),
267                )
268                .await?;
269
270            let deserialized = protocol::deserialize_for_http(response_body)
271                .map_err(|_| DecthingsClientError::InvalidMessage)?;
272            Ok(deserialized)
273        }
274        .await;
275
276        on_result(res);
277        true
278    }
279}
280
281pub struct DecthingsClient {
282    rpc: DecthingsClientRpc,
283    pub dataset: rpc::dataset::DatasetRpc,
284    pub debug: rpc::debug::DebugRpc,
285    pub fs: rpc::fs::FsRpc,
286    pub image: rpc::image::ImageRpc,
287    #[cfg(feature = "events")]
288    pub language: rpc::language::LanguageRpc,
289    pub model: rpc::model::ModelRpc,
290    pub persistent_launcher: rpc::persistent_launcher::PersistentLauncherRpc,
291    pub spawned: rpc::spawned::SpawnedRpc,
292    pub terminal: rpc::terminal::TerminalRpc,
293}
294
295impl Default for DecthingsClient {
296    fn default() -> Self {
297        Self::new(Default::default())
298    }
299}
300
301impl DecthingsClient {
302    pub fn new(options: DecthingsClientOptions) -> Self {
303        let rpc = DecthingsClientRpc::new(options);
304        Self {
305            dataset: rpc::dataset::DatasetRpc::new(rpc.clone()),
306            debug: rpc::debug::DebugRpc::new(rpc.clone()),
307            fs: rpc::fs::FsRpc::new(rpc.clone()),
308            image: rpc::image::ImageRpc::new(rpc.clone()),
309            #[cfg(feature = "events")]
310            language: rpc::language::LanguageRpc::new(rpc.clone()),
311            model: rpc::model::ModelRpc::new(rpc.clone()),
312            persistent_launcher: rpc::persistent_launcher::PersistentLauncherRpc::new(rpc.clone()),
313            spawned: rpc::spawned::SpawnedRpc::new(rpc.clone()),
314            terminal: rpc::terminal::TerminalRpc::new(rpc.clone()),
315            rpc,
316        }
317    }
318
319    /// Call an RPC method on the server.
320    ///
321    /// You most likely want to use the helper classes (client.model, client.dataset, etc.) instead.
322    pub async fn raw_method_call<P: serde::Serialize, D: AsRef<[u8]>>(
323        &self,
324        api: &str,
325        method: &str,
326        params: P,
327        data: impl AsRef<[D]>,
328    ) -> Result<(bytes::Bytes, Vec<bytes::Bytes>), DecthingsClientError> {
329        let (tx, rx) = tokio::sync::oneshot::channel();
330        self.rpc
331            .raw_method_call(api, method, params, data, RpcProtocol::Http, |res| {
332                tx.send(res).ok();
333                StateModification {
334                    add_events: vec![],
335                    remove_events: vec![],
336                }
337            })
338            .await;
339        rx.await.unwrap()
340    }
341
342    #[cfg(feature = "events")]
343    pub async fn on_event(
344        &self,
345        cb: impl Fn(&event::DecthingsEvent) + Send + Sync + 'static,
346    ) -> event::EventListenerDisposer {
347        self.rpc.on_event(Box::new(cb)).await
348    }
349
350    pub async fn set_api_key(&self, api_key: String) {
351        self.rpc.set_api_key(api_key).await;
352    }
353}