Skip to main content

trellis_client/
operations.rs

1use std::future::Future;
2use std::marker::PhantomData;
3
4use futures_util::stream::{self, BoxStream};
5use futures_util::StreamExt;
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use serde_json::{json, Value};
8
9use crate::TrellisClientError;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(rename_all = "lowercase")]
13pub enum OperationState {
14    Pending,
15    Running,
16    Completed,
17    Failed,
18    Cancelled,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
22pub struct OperationRefData {
23    pub id: String,
24    pub service: String,
25    pub operation: String,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
29#[serde(rename_all = "camelCase")]
30pub struct OperationSnapshot<TProgress = Value, TOutput = Value> {
31    pub revision: u64,
32    pub state: OperationState,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub progress: Option<TProgress>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub output: Option<TOutput>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
40#[serde(rename_all = "camelCase")]
41struct AcceptedEnvelope<TProgress = Value, TOutput = Value> {
42    kind: String,
43    #[serde(rename = "ref")]
44    operation_ref: OperationRefData,
45    snapshot: OperationSnapshot<TProgress, TOutput>,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49#[serde(rename_all = "camelCase")]
50struct SnapshotFrame<TProgress = Value, TOutput = Value> {
51    kind: String,
52    snapshot: OperationSnapshot<TProgress, TOutput>,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
56#[serde(tag = "type", rename_all = "lowercase")]
57pub enum OperationEvent<TProgress = Value, TOutput = Value> {
58    Accepted {
59        snapshot: OperationSnapshot<TProgress, TOutput>,
60    },
61    Started {
62        snapshot: OperationSnapshot<TProgress, TOutput>,
63    },
64    Progress {
65        snapshot: OperationSnapshot<TProgress, TOutput>,
66    },
67    Completed {
68        snapshot: OperationSnapshot<TProgress, TOutput>,
69    },
70    Failed {
71        snapshot: OperationSnapshot<TProgress, TOutput>,
72    },
73    Cancelled {
74        snapshot: OperationSnapshot<TProgress, TOutput>,
75    },
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79#[serde(rename_all = "camelCase")]
80struct EventFrame<TProgress = Value, TOutput = Value> {
81    kind: String,
82    event: OperationEvent<TProgress, TOutput>,
83}
84
85pub trait OperationDescriptor {
86    type Input: Serialize;
87    type Progress: DeserializeOwned + Send + 'static;
88    type Output: DeserializeOwned + Send + 'static;
89
90    const KEY: &'static str;
91    const SUBJECT: &'static str;
92    const CALLER_CAPABILITIES: &'static [&'static str];
93    const READ_CAPABILITIES: &'static [&'static str];
94    const CANCEL_CAPABILITIES: &'static [&'static str];
95    const CANCELABLE: bool;
96}
97
98#[doc(hidden)]
99pub trait OperationTransport {
100    fn request_json_value<'a>(
101        &'a self,
102        subject: String,
103        body: Value,
104    ) -> impl Future<Output = Result<Value, TrellisClientError>> + Send + 'a;
105
106    fn watch_json_value<'a>(
107        &'a self,
108        subject: String,
109        body: Value,
110    ) -> impl Future<
111        Output = Result<BoxStream<'a, Result<Value, TrellisClientError>>, TrellisClientError>,
112    > + Send
113           + 'a;
114}
115
/// Entry point for starting operations described by `D` over the transport `T`.
///
/// Only a borrowed transport is stored; `D` is a type-level marker selecting the descriptor.
pub struct OperationInvoker<'a, T, D> {
    transport: &'a T,
    // `fn() -> D` rather than `D`: no `D` value is ever stored, so the invoker's
    // `Send`/`Sync` and drop-check behavior must not depend on the descriptor type.
    // Variance in `D` stays covariant.
    _descriptor: PhantomData<fn() -> D>,
}
120
121pub struct OperationRef<'a, T, D> {
122    transport: &'a T,
123    data: OperationRefData,
124    _descriptor: PhantomData<D>,
125}
126
127fn is_terminal_state(state: &OperationState) -> bool {
128    matches!(
129        state,
130        OperationState::Completed | OperationState::Failed | OperationState::Cancelled
131    )
132}
133
134impl<'a, T, D> OperationInvoker<'a, T, D> {
135    pub fn new(transport: &'a T) -> Self {
136        Self {
137            transport,
138            _descriptor: PhantomData,
139        }
140    }
141}
142
143impl<'a, T, D> OperationInvoker<'a, T, D>
144where
145    T: OperationTransport,
146    D: OperationDescriptor,
147    D::Progress: Send,
148    D::Output: Send,
149{
150    pub async fn start(
151        &self,
152        input: &D::Input,
153    ) -> Result<OperationRef<'a, T, D>, TrellisClientError> {
154        let body = serde_json::to_value(input)?;
155        let response = self
156            .transport
157            .request_json_value(D::SUBJECT.to_string(), body)
158            .await?;
159        let accepted: AcceptedEnvelope<D::Progress, D::Output> = serde_json::from_value(response)?;
160        if accepted.kind != "accepted" {
161            return Err(TrellisClientError::OperationProtocol(format!(
162                "expected accepted envelope, got '{}'",
163                accepted.kind
164            )));
165        }
166        Ok(OperationRef {
167            transport: self.transport,
168            data: accepted.operation_ref,
169            _descriptor: PhantomData,
170        })
171    }
172}
173
174impl<'a, T, D> OperationRef<'a, T, D> {
175    pub fn id(&self) -> &str {
176        &self.data.id
177    }
178
179    pub fn service(&self) -> &str {
180        &self.data.service
181    }
182
183    pub fn operation(&self) -> &str {
184        &self.data.operation
185    }
186}
187
188impl<'a, T, D> OperationRef<'a, T, D>
189where
190    T: OperationTransport,
191    D: OperationDescriptor,
192{
193    pub async fn get(
194        &self,
195    ) -> Result<OperationSnapshot<D::Progress, D::Output>, TrellisClientError> {
196        let body = json!({
197            "action": "get",
198            "operationId": self.id(),
199        });
200        let response = self
201            .transport
202            .request_json_value(control_subject(D::SUBJECT), body)
203            .await?;
204        let frame: SnapshotFrame<D::Progress, D::Output> = serde_json::from_value(response)?;
205        if frame.kind != "snapshot" {
206            return Err(TrellisClientError::OperationProtocol(format!(
207                "expected snapshot frame, got '{}'",
208                frame.kind
209            )));
210        }
211        Ok(frame.snapshot)
212    }
213
214    pub async fn wait(
215        &self,
216    ) -> Result<OperationSnapshot<D::Progress, D::Output>, TrellisClientError> {
217        let body = json!({
218            "action": "wait",
219            "operationId": self.id(),
220        });
221        let response = self
222            .transport
223            .request_json_value(control_subject(D::SUBJECT), body)
224            .await?;
225        let frame: SnapshotFrame<D::Progress, D::Output> = serde_json::from_value(response)?;
226        if frame.kind != "snapshot" {
227            return Err(TrellisClientError::OperationProtocol(format!(
228                "expected snapshot frame, got '{}'",
229                frame.kind
230            )));
231        }
232        if !is_terminal_state(&frame.snapshot.state) {
233            return Err(TrellisClientError::OperationProtocol(
234                "wait returned non-terminal snapshot".to_string(),
235            ));
236        }
237        Ok(frame.snapshot)
238    }
239
240    pub async fn cancel(
241        &self,
242    ) -> Result<OperationSnapshot<D::Progress, D::Output>, TrellisClientError> {
243        let body = json!({
244            "action": "cancel",
245            "operationId": self.id(),
246        });
247        let response = self
248            .transport
249            .request_json_value(control_subject(D::SUBJECT), body)
250            .await?;
251        let frame: SnapshotFrame<D::Progress, D::Output> = serde_json::from_value(response)?;
252        if frame.kind != "snapshot" {
253            return Err(TrellisClientError::OperationProtocol(format!(
254                "expected snapshot frame, got '{}'",
255                frame.kind
256            )));
257        }
258        Ok(frame.snapshot)
259    }
260
261    pub async fn watch(
262        &self,
263    ) -> Result<
264        BoxStream<'a, Result<OperationEvent<D::Progress, D::Output>, TrellisClientError>>,
265        TrellisClientError,
266    > {
267        let control = control_subject(D::SUBJECT);
268        let body = json!({
269            "action": "watch",
270            "operationId": self.id(),
271        });
272        let response = self.transport.watch_json_value(control, body).await?;
273        Ok(Box::pin(stream::try_unfold(
274            (response, false),
275            |(mut response, done)| async move {
276                if done {
277                    return Ok(None);
278                }
279
280                loop {
281                    match response.next().await {
282                        Some(frame) => {
283                            let event = match frame {
284                                Ok(value) => {
285                                    match decode_watch_frame::<D::Progress, D::Output>(value) {
286                                        Ok(Some(event)) => event,
287                                        Ok(None) => continue,
288                                        Err(error) => return Err(error),
289                                    }
290                                }
291                                Err(error) => return Err(error),
292                            };
293
294                            let terminal = is_terminal_event(&event);
295                            return Ok(Some((event, (response, terminal))));
296                        }
297                        None => return Ok(None),
298                    }
299                }
300            },
301        )))
302    }
303}
304
305fn decode_watch_frame<TProgress: DeserializeOwned, TOutput: DeserializeOwned>(
306    value: Value,
307) -> Result<Option<OperationEvent<TProgress, TOutput>>, TrellisClientError> {
308    if value.get("kind").and_then(Value::as_str) == Some("keepalive") {
309        return Ok(None);
310    }
311
312    let kind = value.get("kind").and_then(Value::as_str).ok_or_else(|| {
313        TrellisClientError::OperationProtocol("expected watch frame kind".to_string())
314    })?;
315
316    match kind {
317        "snapshot" => {
318            let frame: SnapshotFrame<TProgress, TOutput> = serde_json::from_value(value)?;
319            Ok(Some(snapshot_to_event(frame.snapshot)))
320        }
321        "event" => {
322            let frame: EventFrame<TProgress, TOutput> = serde_json::from_value(value)?;
323            Ok(Some(frame.event))
324        }
325        _ => Err(TrellisClientError::OperationProtocol(
326            "expected snapshot/event/keepalive frame".to_string(),
327        )),
328    }
329}
330
331fn snapshot_to_event<TProgress, TOutput>(
332    snapshot: OperationSnapshot<TProgress, TOutput>,
333) -> OperationEvent<TProgress, TOutput> {
334    match snapshot.state {
335        OperationState::Pending => OperationEvent::Accepted { snapshot },
336        OperationState::Running => OperationEvent::Started { snapshot },
337        OperationState::Completed => OperationEvent::Completed { snapshot },
338        OperationState::Failed => OperationEvent::Failed { snapshot },
339        OperationState::Cancelled => OperationEvent::Cancelled { snapshot },
340    }
341}
342
343fn is_terminal_event<TProgress, TOutput>(event: &OperationEvent<TProgress, TOutput>) -> bool {
344    matches!(
345        event,
346        OperationEvent::Completed { .. }
347            | OperationEvent::Failed { .. }
348            | OperationEvent::Cancelled { .. }
349    )
350}
351
/// Builds the control subject for an operation subject by appending `".control"`.
///
/// For example, `"svc.op"` becomes `"svc.op.control"`.
pub fn control_subject(subject: &str) -> String {
    [subject, ".control"].concat()
}