dynamo_runtime/protocols/
annotated.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::*;
5use crate::{Result, error};
6use maybe_error::MaybeError;
7
/// Implemented by types that can report a set of requested annotation names.
pub trait AnnotationsProvider {
    /// Returns the annotation names, or `None` when there are none to report.
    fn annotations(&self) -> Option<Vec<String>>;

    /// Returns `true` when `annotation` exactly matches one of the names
    /// returned by [`AnnotationsProvider::annotations`].
    ///
    /// A provider that returns `None` has no annotations, so this is `false`.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.annotations()
            .is_some_and(|names| names.iter().any(|name| name == annotation))
    }
}
16
17/// Our services have the option of returning an "annotated" stream, which allows use
18/// to include additional information with each delta. This is useful for debugging,
19/// performance benchmarking, and improved observability.
20#[derive(Serialize, Deserialize, Clone, Debug)]
21pub struct Annotated<R> {
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub data: Option<R>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub id: Option<String>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub event: Option<String>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub comment: Option<Vec<String>>,
30}
31
32impl<R> Annotated<R> {
33    /// Create a new annotated stream from the given error
34    pub fn from_error(error: String) -> Self {
35        Self {
36            data: None,
37            id: None,
38            event: Some("error".to_string()),
39            comment: Some(vec![error]),
40        }
41    }
42
43    /// Create a new annotated stream from the given data
44    pub fn from_data(data: R) -> Self {
45        Self {
46            data: Some(data),
47            id: None,
48            event: None,
49            comment: None,
50        }
51    }
52
53    /// Add an annotation to the stream
54    ///
55    /// Annotations populate the `event` field and the `comment` field
56    pub fn from_annotation<S: Serialize>(
57        name: impl Into<String>,
58        value: &S,
59    ) -> Result<Self, serde_json::Error> {
60        Ok(Self {
61            data: None,
62            id: None,
63            event: Some(name.into()),
64            comment: Some(vec![serde_json::to_string(value)?]),
65        })
66    }
67
68    /// Convert to a [`Result<Self, String>`]
69    /// If [`Self::event`] is "error", return an error message(s) held by [`Self::comment`]
70    pub fn ok(self) -> Result<Self, String> {
71        if let Some(event) = &self.event
72            && event == "error"
73        {
74            return Err(self
75                .comment
76                .unwrap_or(vec!["unknown error".to_string()])
77                .join(", "));
78        }
79        Ok(self)
80    }
81
82    pub fn is_ok(&self) -> bool {
83        self.event.as_deref() != Some("error")
84    }
85
86    pub fn is_err(&self) -> bool {
87        !self.is_ok()
88    }
89
90    pub fn is_event(&self) -> bool {
91        self.event.is_some()
92    }
93
94    pub fn transfer<U: Serialize>(self, data: Option<U>) -> Annotated<U> {
95        Annotated::<U> {
96            data,
97            id: self.id,
98            event: self.event,
99            comment: self.comment,
100        }
101    }
102
103    /// Apply a mapping/transformation to the data field
104    /// If the mapping fails, the error is returned as an annotated stream
105    pub fn map_data<U, F>(self, transform: F) -> Annotated<U>
106    where
107        F: FnOnce(R) -> Result<U, String>,
108    {
109        match self.data.map(transform).transpose() {
110            Ok(data) => Annotated::<U> {
111                data,
112                id: self.id,
113                event: self.event,
114                comment: self.comment,
115            },
116            Err(e) => Annotated::from_error(e),
117        }
118    }
119
120    pub fn is_error(&self) -> bool {
121        self.event.as_deref() == Some("error")
122    }
123
124    pub fn into_result(self) -> Result<Option<R>> {
125        match self.data {
126            Some(data) => Ok(Some(data)),
127            None => match self.event {
128                Some(event) if event == "error" => Err(error!(
129                    self.comment
130                        .unwrap_or(vec!["unknown error".to_string()])
131                        .join(", ")
132                ))?,
133                _ => Ok(None),
134            },
135        }
136    }
137}
138
139impl<R> MaybeError for Annotated<R>
140where
141    R: for<'de> Deserialize<'de> + Serialize,
142{
143    fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
144        Annotated::from_error(format!("{:?}", err))
145    }
146
147    fn err(&self) -> Option<anyhow::Error> {
148        if self.is_error() {
149            if let Some(comment) = &self.comment
150                && !comment.is_empty()
151            {
152                return Some(anyhow::Error::msg(comment.join("; ")));
153            }
154            Some(anyhow::Error::msg("unknown error"))
155        } else {
156            None
157        }
158    }
159}
160
161// impl<R> Annotated<R>
162// where
163//     R: for<'de> Deserialize<'de> + Serialize,
164// {
165//     pub fn convert_sse_stream(
166//         stream: DataStream<Result<Message, SseCodecError>>,
167//     ) -> DataStream<Annotated<R>> {
168//         let stream = stream.map(|message| match message {
169//             Ok(message) => {
170//                 let delta = Annotated::<R>::try_from(message);
171//                 match delta {
172//                     Ok(delta) => delta,
173//                     Err(e) => Annotated::from_error(e.to_string()),
174//                 }
175//             }
176//             Err(e) => Annotated::from_error(e.to_string()),
177//         });
178//         Box::pin(stream)
179//     }
180// }
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the error view of `Annotated` for a data item, an error item,
    /// and an item built from a boxed error.
    #[test]
    fn test_maybe_error() {
        // A data-carrying item reports no error.
        let with_data = Annotated::from_data("Test data".to_string());
        assert!(with_data.err().is_none());
        assert!(with_data.is_ok());

        // An explicit error item surfaces its message unchanged.
        let failed = Annotated::<String>::from_error("Test error 2".to_string());
        let rendered = failed.err().unwrap().to_string();
        assert_eq!(rendered, "Test error 2");
        assert!(failed.is_err());

        // Construction from a boxed error also yields an error item.
        let boxed: Box<dyn std::error::Error + Send + Sync> =
            anyhow::Error::msg("Test error 3".to_string()).into();
        let from_boxed = Annotated::<String>::from_err(boxed);
        assert!(from_boxed.is_err());
    }
}
200}