// dynamo_runtime/protocols/annotated.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use super::*;
17use crate::{error, Result};
18use maybe_error::MaybeError;
19
/// Something that can report a list of string annotations attached to it.
pub trait AnnotationsProvider {
    /// Returns the annotations, if any.
    fn annotations(&self) -> Option<Vec<String>>;

    /// Returns `true` if [`AnnotationsProvider::annotations`] yields a list containing an
    /// entry exactly equal to `annotation`; returns `false` when that list is `None`.
    fn has_annotation(&self, annotation: &str) -> bool {
        match self.annotations() {
            Some(found) => found.iter().any(|entry| entry.as_str() == annotation),
            None => false,
        }
    }
}
28
29/// Our services have the option of returning an "annotated" stream, which allows use
30/// to include additional information with each delta. This is useful for debugging,
31/// performance benchmarking, and improved observability.
32#[derive(Serialize, Deserialize, Clone, Debug)]
33pub struct Annotated<R> {
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub data: Option<R>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub id: Option<String>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub event: Option<String>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub comment: Option<Vec<String>>,
42}
43
44impl<R> Annotated<R> {
45    /// Create a new annotated stream from the given error
46    pub fn from_error(error: String) -> Self {
47        Self {
48            data: None,
49            id: None,
50            event: Some("error".to_string()),
51            comment: Some(vec![error]),
52        }
53    }
54
55    /// Create a new annotated stream from the given data
56    pub fn from_data(data: R) -> Self {
57        Self {
58            data: Some(data),
59            id: None,
60            event: None,
61            comment: None,
62        }
63    }
64
65    /// Add an annotation to the stream
66    ///
67    /// Annotations populate the `event` field and the `comment` field
68    pub fn from_annotation<S: Serialize>(
69        name: impl Into<String>,
70        value: &S,
71    ) -> Result<Self, serde_json::Error> {
72        Ok(Self {
73            data: None,
74            id: None,
75            event: Some(name.into()),
76            comment: Some(vec![serde_json::to_string(value)?]),
77        })
78    }
79
80    /// Convert to a [`Result<Self, String>`]
81    /// If [`Self::event`] is "error", return an error message(s) held by [`Self::comment`]
82    pub fn ok(self) -> Result<Self, String> {
83        if let Some(event) = &self.event {
84            if event == "error" {
85                return Err(self
86                    .comment
87                    .unwrap_or(vec!["unknown error".to_string()])
88                    .join(", "));
89            }
90        }
91        Ok(self)
92    }
93
94    pub fn is_ok(&self) -> bool {
95        self.event.as_deref() != Some("error")
96    }
97
98    pub fn is_err(&self) -> bool {
99        !self.is_ok()
100    }
101
102    pub fn is_event(&self) -> bool {
103        self.event.is_some()
104    }
105
106    pub fn transfer<U: Serialize>(self, data: Option<U>) -> Annotated<U> {
107        Annotated::<U> {
108            data,
109            id: self.id,
110            event: self.event,
111            comment: self.comment,
112        }
113    }
114
115    /// Apply a mapping/transformation to the data field
116    /// If the mapping fails, the error is returned as an annotated stream
117    pub fn map_data<U, F>(self, transform: F) -> Annotated<U>
118    where
119        F: FnOnce(R) -> Result<U, String>,
120    {
121        match self.data.map(transform).transpose() {
122            Ok(data) => Annotated::<U> {
123                data,
124                id: self.id,
125                event: self.event,
126                comment: self.comment,
127            },
128            Err(e) => Annotated::from_error(e),
129        }
130    }
131
132    pub fn is_error(&self) -> bool {
133        self.event.as_deref() == Some("error")
134    }
135
136    pub fn into_result(self) -> Result<Option<R>> {
137        match self.data {
138            Some(data) => Ok(Some(data)),
139            None => match self.event {
140                Some(event) if event == "error" => Err(error!(self
141                    .comment
142                    .unwrap_or(vec!["unknown error".to_string()])
143                    .join(", ")))?,
144                _ => Ok(None),
145            },
146        }
147    }
148}
149
150impl<R> MaybeError for Annotated<R>
151where
152    R: for<'de> Deserialize<'de> + Serialize,
153{
154    fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
155        Annotated::from_error(format!("{:?}", err))
156    }
157
158    fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
159        if self.is_error() {
160            if let Some(comment) = &self.comment {
161                if !comment.is_empty() {
162                    return Some(anyhow::Error::msg(comment.join("; ")).into());
163                }
164            }
165            Some(anyhow::Error::msg("unknown error").into())
166        } else {
167            None
168        }
169    }
170}
171
// impl<R> Annotated<R>
// where
//     R: for<'de> Deserialize<'de> + Serialize,
// {
//     pub fn convert_sse_stream(
//         stream: DataStream<Result<Message, SseCodecError>>,
//     ) -> DataStream<Annotated<R>> {
//         let stream = stream.map(|message| match message {
//             Ok(message) => {
//                 let delta = Annotated::<R>::try_from(message);
//                 match delta {
//                     Ok(delta) => delta,
//                     Err(e) => Annotated::from_error(e.to_string()),
//                 }
//             }
//             Err(e) => Annotated::from_error(e.to_string()),
//         });
//         Box::pin(stream)
//     }
// }
192
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_maybe_error() {
        let annotated = Annotated::from_data("Test data".to_string());
        assert!(annotated.err().is_none());
        assert!(annotated.is_ok());
        assert!(!annotated.is_err());

        let annotated = Annotated::<String>::from_error("Test error 2".to_string());
        assert_eq!(format!("{}", annotated.err().unwrap()), "Test error 2");
        assert!(!annotated.is_ok());
        assert!(annotated.is_err());

        let annotated =
            Annotated::<String>::from_err(anyhow::Error::msg("Test error 3".to_string()).into());
        assert_eq!(format!("{}", annotated.err().unwrap()), "Test error 3");
        assert!(!annotated.is_ok());
        assert!(annotated.is_err());
    }

    /// `ok()` / `into_result()` on data, error, and plain-event items.
    #[test]
    fn test_ok_and_into_result() {
        let annotated = Annotated::from_data(7u32);
        assert_eq!(annotated.clone().ok().unwrap().data, Some(7));
        assert_eq!(annotated.into_result().unwrap(), Some(7));

        let annotated = Annotated::<u32>::from_error("boom".to_string());
        assert_eq!(annotated.clone().ok().unwrap_err(), "boom");
        assert_eq!(annotated.into_result().unwrap_err().to_string(), "boom");

        // A non-error event with no data is "ok" and converts to `Ok(None)`.
        let annotated = Annotated::<u32> {
            data: None,
            id: None,
            event: Some("status".to_string()),
            comment: None,
        };
        assert!(annotated.is_event());
        assert!(annotated.is_ok());
        assert_eq!(annotated.into_result().unwrap(), None);
    }

    /// `map_data` transforms the payload, and turns a failed transform into an error item.
    #[test]
    fn test_map_data() {
        let mapped = Annotated::from_data(2u32).map_data(|x| Ok(x * 10));
        assert_eq!(mapped.data, Some(20));
        assert!(mapped.is_ok());

        let failed = Annotated::from_data(2u32).map_data::<u32, _>(|_| Err("bad".to_string()));
        assert!(failed.is_error());
        assert!(failed.data.is_none());
        assert_eq!(failed.comment, Some(vec!["bad".to_string()]));
    }

    /// `from_annotation` serializes into `comment`; `transfer` keeps the metadata.
    #[test]
    fn test_from_annotation_and_transfer() {
        let annotated = Annotated::<String>::from_annotation("token_ids", &[1, 2, 3]).unwrap();
        assert_eq!(annotated.event.as_deref(), Some("token_ids"));
        assert_eq!(annotated.comment, Some(vec!["[1,2,3]".to_string()]));
        assert!(annotated.data.is_none());
        assert!(annotated.is_event());
        assert!(annotated.is_ok());

        let transferred = annotated.transfer(Some(42u32));
        assert_eq!(transferred.data, Some(42));
        assert_eq!(transferred.event.as_deref(), Some("token_ids"));
        assert_eq!(transferred.comment, Some(vec!["[1,2,3]".to_string()]));
    }
}