Skip to main content

agent_client_protocol/
util.rs

1// Types re-exported from crate root
2
3use futures::{
4    future::BoxFuture,
5    stream::{Stream, StreamExt},
6};
7
8mod typed;
9pub use typed::{MatchDispatch, MatchDispatchFrom, TypeNotification};
10
11/// Cast from `N` to `M` by serializing/deserialization to/from JSON.
12pub fn json_cast<N, M>(params: N) -> Result<M, crate::Error>
13where
14    N: serde::Serialize,
15    M: serde::de::DeserializeOwned,
16{
17    let json = serde_json::to_value(params).map_err(|e| {
18        crate::Error::parse_error().data(serde_json::json!({
19            "error": e.to_string(),
20            "phase": "serialization"
21        }))
22    })?;
23    let m = serde_json::from_value(json.clone()).map_err(|e| {
24        crate::Error::parse_error().data(serde_json::json!({
25            "error": e.to_string(),
26            "json": json,
27            "phase": "deserialization"
28        }))
29    })?;
30    Ok(m)
31}
32
33/// Cast incoming request/notification params into a typed payload.
34///
35/// Like [`json_cast`], but deserialization failures become
36/// [`Error::invalid_params`](`crate::Error::invalid_params`) (`-32602`)
37/// instead of a parse error, which is the correct JSON-RPC error code for
38/// malformed method parameters.
39pub fn json_cast_params<N, M>(params: N) -> Result<M, crate::Error>
40where
41    N: serde::Serialize,
42    M: serde::de::DeserializeOwned,
43{
44    let json = serde_json::to_value(params).map_err(|e| {
45        crate::Error::internal_error().data(serde_json::json!({
46            "error": e.to_string(),
47            "phase": "serialization"
48        }))
49    })?;
50    let m = serde_json::from_value(json.clone()).map_err(|e| {
51        crate::Error::invalid_params().data(serde_json::json!({
52            "error": e.to_string(),
53            "json": json,
54            "phase": "deserialization"
55        }))
56    })?;
57    Ok(m)
58}
59
60/// Creates an internal error with the given message
61pub fn internal_error(message: impl ToString) -> crate::Error {
62    crate::Error::internal_error().data(message.to_string())
63}
64
65/// Creates a parse error with the given message
66pub fn parse_error(message: impl ToString) -> crate::Error {
67    crate::Error::parse_error().data(message.to_string())
68}
69
70/// Convert a JSON-RPC id to a serde_json::Value.
71pub(crate) fn id_to_json(id: &crate::schema::v1::RequestId) -> serde_json::Value {
72    serde_json::to_value(id).expect("RequestId serializes infallibly")
73}
74
75pub(crate) fn instrumented_with_connection_name<F>(
76    name: String,
77    task: F,
78) -> tracing::instrument::Instrumented<F> {
79    use tracing::Instrument;
80
81    task.instrument(tracing::info_span!("connection", name = name))
82}
83
84pub(crate) async fn instrument_with_connection_name<R>(
85    name: Option<String>,
86    task: impl Future<Output = R>,
87) -> R {
88    if let Some(name) = name {
89        instrumented_with_connection_name(name.clone(), task).await
90    } else {
91        task.await
92    }
93}
94
95/// Run two fallible futures concurrently, returning when both complete successfully
96/// or when either fails.
97pub async fn both<E>(
98    a: impl Future<Output = Result<(), E>>,
99    b: impl Future<Output = Result<(), E>>,
100) -> Result<(), E> {
101    let ((), ()) = futures::future::try_join(a, b).await?;
102    Ok(())
103}
104
105/// Run `background` until `foreground` completes.
106///
107/// Returns the result of `foreground`. If `background` errors before
108/// `foreground` completes, the error is propagated. If `background`
109/// completes with `Ok(())`, we continue waiting for `foreground`.
110pub async fn run_until<T, E>(
111    background: impl Future<Output = Result<(), E>>,
112    foreground: impl Future<Output = Result<T, E>>,
113) -> Result<T, E> {
114    use futures::future::{Either, select};
115    use std::pin::pin;
116
117    match select(pin!(background), pin!(foreground)).await {
118        Either::Left((bg_result, fg_future)) => {
119            // Background finished first
120            bg_result?; // propagate error, or if Ok(()), keep waiting
121            fg_future.await
122        }
123        Either::Right((fg_result, _bg_future)) => {
124            // Foreground finished first, drop background
125            fg_result
126        }
127    }
128}
129
130/// Process items from a stream concurrently.
131///
132/// For each item received from `stream`, calls `process_fn` to create a future,
133/// then runs all futures concurrently. If any future returns an error,
134/// stops processing and returns that error.
135///
136/// This is useful for patterns where you receive work items from a channel
137/// and want to process them concurrently while respecting backpressure.
138pub async fn process_stream_concurrently<T, F>(
139    stream: impl Stream<Item = T>,
140    process_fn: F,
141    process_fn_hack: impl for<'a> Fn(&'a F, T) -> BoxFuture<'a, Result<(), crate::Error>>,
142) -> Result<(), crate::Error>
143where
144    F: AsyncFn(T) -> Result<(), crate::Error>,
145{
146    use std::pin::pin;
147
148    use futures::stream::{FusedStream, FuturesUnordered};
149    use futures_concurrency::future::Race;
150
151    enum Event<T> {
152        NewItem(Option<T>),
153        FutureCompleted(Option<Result<(), crate::Error>>),
154    }
155
156    let mut stream = pin!(stream.fuse());
157    let mut futures: FuturesUnordered<_> = FuturesUnordered::new();
158
159    loop {
160        // If we have no futures to run, wait until we do.
161        if futures.is_empty() {
162            match stream.next().await {
163                Some(item) => futures.push(process_fn_hack(&process_fn, item)),
164                None => return Ok(()),
165            }
166            continue;
167        }
168
169        // If there are no more items coming in, just drain our queue and return.
170        if stream.is_terminated() {
171            while let Some(result) = futures.next().await {
172                result?;
173            }
174            return Ok(());
175        }
176
177        // Otherwise, race between getting a new item and completing a future.
178        let event = (async { Event::NewItem(stream.next().await) }, async {
179            Event::FutureCompleted(futures.next().await)
180        })
181            .race()
182            .await;
183
184        match event {
185            Event::NewItem(Some(item)) => {
186                futures.push(process_fn_hack(&process_fn, item));
187            }
188            Event::FutureCompleted(Some(result)) => {
189                result?;
190            }
191            Event::NewItem(None) | Event::FutureCompleted(None) => {
192                // Stream closed, loop will catch is_terminated
193                // No futures were pending, shouldn't happen since we checked is_empty
194            }
195        }
196    }
197}