1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
//! Concrete [`PeerHandle`] implementation that delegates to the
//! [`ServerRequestDispatcher`].
//!
//! `DispatchPeerHandle` does NOT own a channel. It holds an
//! `Arc<ServerRequestDispatcher>` and delegates every outbound RPC to
//! `dispatcher.dispatch(...)`. The dispatcher owns the correlation layer
//! (pending oneshot map keyed by correlation id) and the drain-to-transport
//! task. This avoids the anti-pattern of ad-hoc per-site channel
//! construction: every peer handle shares the single correlation authority.
//!
//! Deserialization: the dispatcher returns `serde_json::Value`; the
//! `DispatchPeerHandle` parses into the typed result and surfaces malformed
//! responses as a protocol `INTERNAL_ERROR`.
#![cfg(not(target_arch = "wasm32"))]
use std::sync::Arc;
use async_trait::async_trait;
use crate::error::{Error, ErrorCode, Result};
use crate::server::roots::ListRootsResult;
use crate::server::server_request_dispatcher::ServerRequestDispatcher;
use crate::shared::peer::PeerHandle;
use crate::types::sampling::{CreateMessageParams, CreateMessageResult};
use crate::types::{ProgressToken, ServerRequest};
/// Peer handle whose outbound RPCs are all routed through one shared
/// [`ServerRequestDispatcher`] instead of a [`PeerHandle`]-private channel.
///
/// A new instance is built for every request at each `ServerCore` dispatch
/// site, provided the owning `ServerCore` was configured through
/// [`crate::server::core::ServerCore::with_server_request_dispatcher`].
/// The only state is one `Arc`, so building it is essentially a refcount
/// bump and per-request construction is not a cost concern.
#[derive(Debug)]
pub struct DispatchPeerHandle {
    /// Shared correlation authority; every outbound request issued through
    /// this handle goes through `dispatcher.dispatch(...)`.
    dispatcher: Arc<ServerRequestDispatcher>,
}
impl DispatchPeerHandle {
/// Build a peer handle around a shared dispatcher.
///
/// Pub (not `pub(crate)`) so the `#[doc(hidden)] __test_support`
/// re-export in `src/lib.rs` can link from integration tests; the
/// enclosing `peer_impl` module is `pub(crate)`, so this stays
/// internal from a doc/discoverability standpoint.
pub fn new(dispatcher: Arc<ServerRequestDispatcher>) -> Self {
Self { dispatcher }
}
}
#[async_trait]
impl PeerHandle for DispatchPeerHandle {
async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
let value = self
.dispatcher
.dispatch(ServerRequest::CreateMessage(Box::new(params)))
.await?;
serde_json::from_value::<CreateMessageResult>(value).map_err(|e| {
Error::protocol(
ErrorCode::INTERNAL_ERROR,
format!("Invalid sample response: {e}"),
)
})
}
async fn list_roots(&self) -> Result<ListRootsResult> {
let value = self.dispatcher.dispatch(ServerRequest::ListRoots).await?;
serde_json::from_value::<ListRootsResult>(value).map_err(|e| {
Error::protocol(
ErrorCode::INTERNAL_ERROR,
format!("Invalid list_roots response: {e}"),
)
})
}
async fn progress_notify(
&self,
_token: ProgressToken,
_progress: f64,
_total: Option<f64>,
_message: Option<String>,
) -> Result<()> {
// Progress is a notification (one-way, no response) not a
// request/response. The existing `Server::notification_tx:
// Sender<Notification>` channel is the right vehicle, but
// DispatchPeerHandle doesn't hold a clone. For this phase we
// preserve the existing `RequestHandlerExtra::report_progress`
// no-op behavior: return Ok(()) silently. Follow-on work can plumb
// notification_tx through DispatchPeerHandle for live progress.
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};
    use tokio::sync::mpsc;

    /// Per-request timeout given to every dispatcher these tests build; kept
    /// short so the timeout test finishes quickly.
    const DISPATCH_TIMEOUT: Duration = Duration::from_millis(40);

    /// Generous ceiling on how long the timeout test may take end to end.
    const TIMEOUT_DEADLINE: Duration = Duration::from_millis(500);

    /// Build a dispatcher over a small bounded channel with a short timeout.
    ///
    /// The receiver is returned so callers can keep the channel open for the
    /// duration of a test. Nothing ever drains it or answers requests.
    fn build_dispatcher_with_short_timeout() -> (
        Arc<ServerRequestDispatcher>,
        mpsc::Receiver<(String, ServerRequest)>,
    ) {
        let (outbound_tx, outbound_rx) = mpsc::channel::<(String, ServerRequest)>(4);
        let dispatcher =
            ServerRequestDispatcher::new_with_channel(outbound_tx).with_timeout(DISPATCH_TIMEOUT);
        (Arc::new(dispatcher), outbound_rx)
    }

    #[tokio::test]
    async fn test_peer_handle_trait_shape() {
        let (dispatcher, _outbound_rx) = build_dispatcher_with_short_timeout();
        // Smoke check of the trait shape: the handle coerces to
        // `Arc<dyn PeerHandle>`, and that Arc can be cloned and stored
        // without any `?Sized` complaints.
        let peer: Arc<dyn PeerHandle> = Arc::new(DispatchPeerHandle::new(dispatcher));
        let _shared = Arc::clone(&peer);
    }

    #[tokio::test]
    async fn test_peer_progress_notify_always_ok() {
        let (dispatcher, _outbound_rx) = build_dispatcher_with_short_timeout();
        let peer = DispatchPeerHandle::new(dispatcher);
        let token = ProgressToken::String("tok-1".to_string());
        let outcome = peer.progress_notify(token, 0.5, Some(1.0), None).await;
        assert!(outcome.is_ok(), "progress_notify is a no-op for this phase");
    }

    #[tokio::test]
    async fn test_peer_sample_propagates_dispatcher_timeout() {
        // The receiver must stay bound (not `_`): dropping a tokio mpsc
        // receiver closes the channel. The intent is presumably that the
        // failure comes from the dispatcher timeout, not from a closed channel.
        let (dispatcher, _outbound_rx) = build_dispatcher_with_short_timeout();
        let peer = DispatchPeerHandle::new(dispatcher);
        // `CreateMessageParams` has no `Default`, so go through its real
        // constructor.
        let params = CreateMessageParams::new(Vec::new());

        let started = Instant::now();
        let outcome = peer.sample(params).await;
        let waited = started.elapsed();

        assert!(
            outcome.is_err(),
            "sample must return Err when dispatcher times out"
        );
        assert!(
            waited < TIMEOUT_DEADLINE,
            "timeout must fire within 500ms (was {:?})",
            waited
        );
    }
}