// agent_client_protocol/util.rs
1// Types re-exported from crate root
2
3use futures::{
4    future::BoxFuture,
5    stream::{Stream, StreamExt},
6};
7
8mod typed;
9pub use typed::{MatchDispatch, MatchDispatchFrom, TypeNotification};
10
11/// Cast from `N` to `M` by serializing/deserialization to/from JSON.
12pub fn json_cast<N, M>(params: N) -> Result<M, crate::Error>
13where
14    N: serde::Serialize,
15    M: serde::de::DeserializeOwned,
16{
17    let json = serde_json::to_value(params).map_err(|e| {
18        crate::Error::parse_error().data(serde_json::json!({
19            "error": e.to_string(),
20            "phase": "serialization"
21        }))
22    })?;
23    let m = serde_json::from_value(json.clone()).map_err(|e| {
24        crate::Error::parse_error().data(serde_json::json!({
25            "error": e.to_string(),
26            "json": json,
27            "phase": "deserialization"
28        }))
29    })?;
30    Ok(m)
31}
32
33/// Cast incoming request/notification params into a typed payload.
34///
35/// Like [`json_cast`], but deserialization failures become
36/// [`Error::invalid_params`](`crate::Error::invalid_params`) (`-32602`)
37/// instead of a parse error, which is the correct JSON-RPC error code for
38/// malformed method parameters.
39pub fn json_cast_params<N, M>(params: N) -> Result<M, crate::Error>
40where
41    N: serde::Serialize,
42    M: serde::de::DeserializeOwned,
43{
44    let json = serde_json::to_value(params).map_err(|e| {
45        crate::Error::internal_error().data(serde_json::json!({
46            "error": e.to_string(),
47            "phase": "serialization"
48        }))
49    })?;
50    let m = serde_json::from_value(json.clone()).map_err(|e| {
51        crate::Error::invalid_params().data(serde_json::json!({
52            "error": e.to_string(),
53            "json": json,
54            "phase": "deserialization"
55        }))
56    })?;
57    Ok(m)
58}
59
60/// Creates an internal error with the given message
61pub fn internal_error(message: impl ToString) -> crate::Error {
62    crate::Error::internal_error().data(message.to_string())
63}
64
65/// Creates a parse error with the given message
66pub fn parse_error(message: impl ToString) -> crate::Error {
67    crate::Error::parse_error().data(message.to_string())
68}
69
70/// Convert a JSON-RPC id to a serde_json::Value.
71pub(crate) fn id_to_json(id: &jsonrpcmsg::Id) -> serde_json::Value {
72    match id {
73        jsonrpcmsg::Id::Number(n) => serde_json::Value::Number((*n).into()),
74        jsonrpcmsg::Id::String(s) => serde_json::Value::String(s.clone()),
75        jsonrpcmsg::Id::Null => serde_json::Value::Null,
76    }
77}
78
79pub(crate) fn instrumented_with_connection_name<F>(
80    name: String,
81    task: F,
82) -> tracing::instrument::Instrumented<F> {
83    use tracing::Instrument;
84
85    task.instrument(tracing::info_span!("connection", name = name))
86}
87
88pub(crate) async fn instrument_with_connection_name<R>(
89    name: Option<String>,
90    task: impl Future<Output = R>,
91) -> R {
92    if let Some(name) = name {
93        instrumented_with_connection_name(name.clone(), task).await
94    } else {
95        task.await
96    }
97}
98
99/// Convert a `crate::Error` into a `crate::jsonrpcmsg::Error`
100#[must_use]
101pub fn into_jsonrpc_error(err: crate::Error) -> crate::jsonrpcmsg::Error {
102    crate::jsonrpcmsg::Error {
103        code: err.code.into(),
104        message: err.message,
105        data: err.data,
106    }
107}
108
109/// Run two fallible futures concurrently, returning when both complete successfully
110/// or when either fails.
111pub async fn both<E>(
112    a: impl Future<Output = Result<(), E>>,
113    b: impl Future<Output = Result<(), E>>,
114) -> Result<(), E> {
115    let ((), ()) = futures::future::try_join(a, b).await?;
116    Ok(())
117}
118
119/// Run `background` until `foreground` completes.
120///
121/// Returns the result of `foreground`. If `background` errors before
122/// `foreground` completes, the error is propagated. If `background`
123/// completes with `Ok(())`, we continue waiting for `foreground`.
124pub async fn run_until<T, E>(
125    background: impl Future<Output = Result<(), E>>,
126    foreground: impl Future<Output = Result<T, E>>,
127) -> Result<T, E> {
128    use futures::future::{Either, select};
129    use std::pin::pin;
130
131    match select(pin!(background), pin!(foreground)).await {
132        Either::Left((bg_result, fg_future)) => {
133            // Background finished first
134            bg_result?; // propagate error, or if Ok(()), keep waiting
135            fg_future.await
136        }
137        Either::Right((fg_result, _bg_future)) => {
138            // Foreground finished first, drop background
139            fg_result
140        }
141    }
142}
143
/// Process items from a stream concurrently.
///
/// For each item received from `stream`, calls `process_fn` to create a future,
/// then runs all futures concurrently. If any future returns an error,
/// stops processing and returns that error.
///
/// This is useful for patterns where you receive work items from a channel
/// and want to process them concurrently while respecting backpressure.
///
/// `process_fn_hack` boxes the future produced by `process_fn` for each item;
/// presumably this indirection works around the inability to name the
/// anonymous future type of an `AsyncFn` bound — TODO confirm with the callers.
pub async fn process_stream_concurrently<T, F>(
    stream: impl Stream<Item = T>,
    process_fn: F,
    process_fn_hack: impl for<'a> Fn(&'a F, T) -> BoxFuture<'a, Result<(), crate::Error>>,
) -> Result<(), crate::Error>
where
    F: AsyncFn(T) -> Result<(), crate::Error>,
{
    use std::pin::pin;

    use futures::stream::{FusedStream, FuturesUnordered};
    use futures_concurrency::future::Race;

    // Outcome of racing "receive the next item" against "finish an in-flight
    // future"; each side wraps `Option` because both sources can be exhausted.
    enum Event<T> {
        NewItem(Option<T>),
        FutureCompleted(Option<Result<(), crate::Error>>),
    }

    // Fusing makes repeated `next()` calls after stream end safe and enables
    // the `is_terminated()` check below.
    let mut stream = pin!(stream.fuse());
    let mut futures: FuturesUnordered<_> = FuturesUnordered::new();

    loop {
        // If we have no futures to run, wait until we do.
        // (Racing here would let `futures.next()` immediately yield `None`.)
        if futures.is_empty() {
            match stream.next().await {
                Some(item) => futures.push(process_fn_hack(&process_fn, item)),
                None => return Ok(()),
            }
            continue;
        }

        // If there are no more items coming in, just drain our queue and return.
        // The first in-flight error short-circuits via `?`, dropping the rest.
        if stream.is_terminated() {
            while let Some(result) = futures.next().await {
                result?;
            }
            return Ok(());
        }

        // Otherwise, race between getting a new item and completing a future.
        // Awaiting the stream only when capacity-limited work is in flight is
        // what provides the backpressure mentioned in the doc comment.
        let event = (async { Event::NewItem(stream.next().await) }, async {
            Event::FutureCompleted(futures.next().await)
        })
            .race()
            .await;

        match event {
            Event::NewItem(Some(item)) => {
                futures.push(process_fn_hack(&process_fn, item));
            }
            Event::FutureCompleted(Some(result)) => {
                // Propagate the first processing error; remaining in-flight
                // futures are dropped (cancelled) as we return.
                result?;
            }
            Event::NewItem(None) | Event::FutureCompleted(None) => {
                // `NewItem(None)`: stream closed — the `is_terminated()` check
                // at the top of the loop handles draining on the next pass.
                // `FutureCompleted(None)`: cannot occur here, since `futures`
                // was verified non-empty before the race.
            }
        }
    }
}