Skip to main content

github_copilot_sdk/
copilot_request_handler.rs

1//! Connection-level interception of the model-layer HTTP and WebSocket traffic
2//! the runtime issues — for both CAPI and BYOK sessions.
3//!
4//! When [`ClientOptions::request_handler`](crate::ClientOptions::request_handler)
5//! is set, the SDK registers itself as the runtime's request handler on
6//! [`Client::start`](crate::Client::start). From then on, whenever the runtime
7//! would issue a model-layer request (inference, `/models`, `/policy`, …) it
8//! asks the registered [`CopilotRequestHandler`] to service it instead of making
9//! the call itself.
10//!
11//! [`CopilotRequestHandler`] is the single seam consumers implement: one HTTP
12//! send method and one WebSocket factory, each defaulting to transparent
13//! pass-through to the real upstream. Override
14//! [`send_request`](CopilotRequestHandler::send_request) to mutate / replace HTTP
15//! requests, or [`open_websocket`](CopilotRequestHandler::open_websocket) to
16//! mutate the handshake or return a custom [`CopilotWebSocketHandler`].
17//!
18//! # Cancellation
19//!
20//! [`CopilotRequestContext::cancel`] fires when the runtime cancels the
21//! in-flight request (for example because the agent turn was aborted). Forward
22//! it to the upstream call so it is torn down too, and stop writing the response.
23
24use std::collections::HashMap;
25use std::pin::Pin;
26use std::sync::{Arc, LazyLock, OnceLock, Weak};
27
28use async_trait::async_trait;
29use base64::Engine;
30use bytes::Bytes;
31use futures_util::{SinkExt, Stream, StreamExt};
32use http::HeaderMap;
33use http::header::{HeaderName, HeaderValue};
34use parking_lot::Mutex;
35use tokio::net::TcpStream;
36use tokio::sync::{Mutex as AsyncMutex, mpsc};
37use tokio_tungstenite::tungstenite::Message;
38use tokio_tungstenite::tungstenite::client::IntoClientRequest;
39use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
40use tokio_util::sync::CancellationToken;
41use tracing::warn;
42
43use crate::generated::api_types::{
44    LlmInferenceHttpRequestChunkRequest, LlmInferenceHttpRequestStartRequest,
45    LlmInferenceHttpRequestStartTransport, LlmInferenceHttpResponseChunkError,
46    LlmInferenceHttpResponseChunkRequest, LlmInferenceHttpResponseStartRequest,
47};
48use crate::{
49    Client, ClientInner, JsonRpcRequest, JsonRpcResponse, RequestId, SessionId, error_codes,
50};
51
52const METHOD_HTTP_REQUEST_START: &str = "llmInference.httpRequestStart";
53const METHOD_HTTP_REQUEST_CHUNK: &str = "llmInference.httpRequestChunk";
54
/// Transport the runtime would otherwise use for an intercepted request.
///
/// [`CopilotRequestTransport::Http`] is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CopilotRequestTransport {
    /// Plain HTTP or SSE. Each response body frame is an opaque byte range.
    #[default]
    Http,
    /// Full-duplex WebSocket. Each request/response body frame maps to exactly
    /// one WebSocket message.
    WebSocket,
}
65
66impl CopilotRequestTransport {
67    fn from_wire(value: Option<LlmInferenceHttpRequestStartTransport>) -> Self {
68        match value {
69            Some(LlmInferenceHttpRequestStartTransport::Websocket) => Self::WebSocket,
70            _ => Self::Http,
71        }
72    }
73}
74
/// Error returned by a [`CopilotRequestHandler`] hook or the response stream.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug)]
#[non_exhaustive]
pub enum CopilotRequestError {
    /// The response was used after the RPC connection to the runtime closed.
    ConnectionClosed,

    /// The response state machine was violated (for example `start` called
    /// twice, or a write before `start`). The payload is the human-readable
    /// description.
    InvalidState(String),

    /// An upstream transport failure while forwarding the request. The payload
    /// is the human-readable description.
    Upstream(String),

    /// A failure surfaced by the consumer's own handler. The payload is the
    /// human-readable description.
    Handler(String),

    /// An RPC error talking to the runtime.
    Rpc(crate::Error),
}
95
96impl CopilotRequestError {
97    /// Construct a handler-level error from a message — the idiomatic way for a
98    /// consumer to fail an intercepted request.
99    pub fn message(message: impl Into<String>) -> Self {
100        Self::Handler(message.into())
101    }
102}
103
104impl std::fmt::Display for CopilotRequestError {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        match self {
107            Self::ConnectionClosed => {
108                f.write_str("Copilot request response used after RPC connection closed")
109            }
110            Self::InvalidState(message) | Self::Upstream(message) | Self::Handler(message) => {
111                f.write_str(message)
112            }
113            Self::Rpc(err) => write!(f, "{err}"),
114        }
115    }
116}
117
118impl std::error::Error for CopilotRequestError {
119    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
120        match self {
121            Self::Rpc(err) => Some(err),
122            _ => None,
123        }
124    }
125}
126
127impl From<crate::Error> for CopilotRequestError {
128    fn from(err: crate::Error) -> Self {
129        Self::Rpc(err)
130    }
131}
132
/// Context describing an intercepted request, shared by the HTTP and WebSocket
/// seams.
///
/// The type is `Clone`; a clone carries a clone of the same [`CancellationToken`].
#[derive(Clone)]
#[non_exhaustive]
pub struct CopilotRequestContext {
    /// Opaque runtime-minted request id, stable across the request lifecycle.
    pub request_id: String,
    /// Id of the runtime session that triggered this request, or `None` when it
    /// was issued outside any session (for example the startup model catalog).
    pub session_id: Option<String>,
    /// Transport the runtime would otherwise use.
    pub transport: CopilotRequestTransport,
    /// Absolute request URL.
    pub url: String,
    /// Request headers, multi-valued.
    pub headers: HeaderMap,
    /// Fires when the runtime cancels this in-flight request.
    pub cancel: CancellationToken,
}
152
/// Streaming response body: a sequence of byte chunks or a terminal error.
///
/// The stream is boxed, pinned, and `Send`, so any concrete stream type with
/// the matching item type can be stored behind it.
pub type CopilotHttpResponseBody =
    Pin<Box<dyn Stream<Item = Result<Bytes, CopilotRequestError>> + Send>>;
156
/// A buffered HTTP request handed to [`CopilotRequestHandler::send_request`].
///
/// Marked `#[non_exhaustive]`: code outside this crate cannot build it with a
/// struct literal, but the public fields can still be read and mutated.
#[non_exhaustive]
pub struct CopilotHttpRequest {
    /// HTTP method (`GET`, `POST`, …).
    pub method: String,
    /// Absolute request URL.
    pub url: String,
    /// Request headers.
    pub headers: HeaderMap,
    /// Fully-buffered request body.
    pub body: Vec<u8>,
    /// Fires when the runtime cancels the request.
    pub cancel: CancellationToken,
}
171
/// A streaming HTTP response returned by [`CopilotRequestHandler::send_request`].
///
/// Marked `#[non_exhaustive]`; use [`CopilotHttpResponse::new`] to construct
/// one from outside this crate.
#[non_exhaustive]
pub struct CopilotHttpResponse {
    /// HTTP status code.
    pub status: u16,
    /// Optional status reason phrase.
    pub status_text: Option<String>,
    /// Response headers.
    pub headers: HeaderMap,
    /// Streaming response body.
    pub body: CopilotHttpResponseBody,
}
184
185impl CopilotHttpResponse {
186    /// Build a response with the given parts.
187    pub fn new(
188        status: u16,
189        status_text: Option<String>,
190        headers: HeaderMap,
191        body: CopilotHttpResponseBody,
192    ) -> Self {
193        Self {
194            status,
195            status_text,
196            headers,
197            body,
198        }
199    }
200}
201
/// One WebSocket message travelling through a [`CopilotWebSocketHandler`],
/// in either direction.
#[derive(Clone)]
pub struct CopilotWebSocketMessage {
    /// Raw message payload bytes.
    pub data: Vec<u8>,
    /// `true` when the payload is a binary frame, `false` when it is a text
    /// frame.
    pub binary: bool,
}

impl CopilotWebSocketMessage {
    /// Build a text message from anything convertible into a `String`; the
    /// payload is the string's UTF-8 bytes. Binary messages have no dedicated
    /// constructor: fill in the public `data` / `binary` fields directly.
    pub fn from_text(data: impl Into<String>) -> Self {
        let text: String = data.into();
        CopilotWebSocketMessage {
            binary: false,
            data: text.into_bytes(),
        }
    }
}
221
/// The runtime-facing side of a WebSocket: a [`CopilotWebSocketHandler`] writes
/// upstream→runtime messages here.
///
/// Cloning is cheap: every clone shares the same underlying exchange through
/// an `Arc`.
#[derive(Clone)]
pub struct CopilotWebSocketResponse {
    /// The in-flight exchange this sink writes response frames to.
    exchange: Arc<CopilotRequestExchange>,
}
228
229impl CopilotWebSocketResponse {
230    fn new(exchange: Arc<CopilotRequestExchange>) -> Self {
231        Self { exchange }
232    }
233
234    /// Forward one upstream message to the runtime.
235    pub async fn send_message(
236        &self,
237        message: CopilotWebSocketMessage,
238    ) -> Result<(), CopilotRequestError> {
239        self.exchange.ensure_ws_started().await?;
240        if message.binary {
241            self.exchange.write_binary(&message.data).await
242        } else {
243            let text = String::from_utf8_lossy(&message.data);
244            self.exchange.write_text(&text).await
245        }
246    }
247
248    /// End the runtime response stream (the upstream connection closed).
249    pub async fn close(&self) -> Result<(), CopilotRequestError> {
250        self.exchange.end_response().await
251    }
252
253    async fn fail(
254        &self,
255        message: impl Into<String>,
256        code: Option<String>,
257    ) -> Result<(), CopilotRequestError> {
258        self.exchange.error_response(message, code).await
259    }
260}
261
/// A per-connection WebSocket handler. The default implementation
/// ([`CopilotWebSocketForwarder`]) bridges to the real upstream;
/// override [`CopilotRequestHandler::open_websocket`] to supply a custom one.
///
/// Implementors must be `Send + Sync`; both methods take `&self`.
#[async_trait]
pub trait CopilotWebSocketHandler: Send + Sync {
    /// Forward one runtime→upstream message.
    async fn send_request_message(
        &self,
        message: CopilotWebSocketMessage,
    ) -> Result<(), CopilotRequestError>;

    /// Tear down the upstream connection.
    async fn close(&self) -> Result<(), CopilotRequestError>;
}
276
277/// The connection-level Copilot request seam.
278///
279/// One implementor services both transports. Defaults forward transparently to
280/// the real upstream, so overriding nothing yields a pass-through; override a
281/// method to mutate or replace traffic.
282#[async_trait]
283pub trait CopilotRequestHandler: Send + Sync + 'static {
284    /// Service one intercepted HTTP request. Default: forward to the real
285    /// upstream via [`forward_http`]. Override to mutate the request before
286    /// forwarding, mutate the response after, or replace the call entirely.
287    async fn send_request(
288        &self,
289        request: CopilotHttpRequest,
290        _ctx: &CopilotRequestContext,
291    ) -> Result<CopilotHttpResponse, CopilotRequestError> {
292        forward_http(request).await
293    }
294
295    /// Open a per-connection WebSocket handler. Default: a
296    /// [`CopilotWebSocketForwarder`] wired to the real upstream.
297    /// Override to mutate the handshake (URL / headers via `ctx`) or return a
298    /// custom handler.
299    ///
300    /// Unlike the other SDKs, Rust passes `response` — the runtime-facing sink
301    /// for upstream→runtime messages — as a second argument here rather than
302    /// exposing a base-class `send_response_message` helper. A custom handler
303    /// must store this `CopilotWebSocketResponse` in the returned handler struct
304    /// and call [`CopilotWebSocketResponse::send_message`] on it to push
305    /// upstream messages back to the runtime.
306    async fn open_websocket(
307        &self,
308        ctx: &CopilotRequestContext,
309        response: CopilotWebSocketResponse,
310    ) -> Result<Box<dyn CopilotWebSocketHandler>, CopilotRequestError> {
311        let handler = CopilotWebSocketForwarder::builder(ctx.url.clone(), ctx.headers.clone())
312            .connect(response)
313            .await?;
314        Ok(Box::new(handler))
315    }
316}
317
318/// Forward through a shared handler, so an `Arc<H>` can be registered while the
319/// consumer retains a handle (for example to read state the handler records).
320#[async_trait]
321impl<H: CopilotRequestHandler> CopilotRequestHandler for Arc<H> {
322    async fn send_request(
323        &self,
324        request: CopilotHttpRequest,
325        ctx: &CopilotRequestContext,
326    ) -> Result<CopilotHttpResponse, CopilotRequestError> {
327        (**self).send_request(request, ctx).await
328    }
329
330    async fn open_websocket(
331        &self,
332        ctx: &CopilotRequestContext,
333        response: CopilotWebSocketResponse,
334    ) -> Result<Box<dyn CopilotWebSocketHandler>, CopilotRequestError> {
335        (**self).open_websocket(ctx, response).await
336    }
337}
/// Connection-scoped headers that must never be copied from the intercepted
/// request onto a fresh upstream connection.
///
/// Entries are lowercase and are compared against `HeaderName::as_str()`.
const FORBIDDEN_HEADERS: &[&str] = &[
    "host",
    "connection",
    "content-length",
    "transfer-encoding",
    "keep-alive",
    "upgrade",
    "proxy-connection",
    "te",
    "trailer",
];
350
351fn is_forbidden_header(name: &HeaderName) -> bool {
352    let name = name.as_str();
353    FORBIDDEN_HEADERS.contains(&name) || name.starts_with("sec-websocket")
354}
355
356/// Drop headers that belong to the inbound connection rather than the request.
357fn strip_forbidden_headers(headers: &mut HeaderMap) {
358    let forbidden: Vec<HeaderName> = headers
359        .keys()
360        .filter(|name| is_forbidden_header(name))
361        .cloned()
362        .collect();
363    for name in forbidden {
364        headers.remove(&name);
365    }
366}
367
/// Shared `reqwest` client used by [`forward_http`], built lazily on first use.
///
/// The redirect policy is `Policy::none()`, so the client does not follow
/// redirects itself. Construction failure panics via `expect`.
static SHARED_HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
    reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .expect("default reqwest client must build")
});
374
375/// Forward an HTTP request to its real upstream and stream the response back.
376///
377/// This is the default behaviour of [`CopilotRequestHandler::send_request`];
378/// consumers that mutate a request can call it to forward the mutated request.
379pub async fn forward_http(
380    request: CopilotHttpRequest,
381) -> Result<CopilotHttpResponse, CopilotRequestError> {
382    let method = reqwest::Method::from_bytes(request.method.as_bytes())
383        .map_err(|e| CopilotRequestError::InvalidState(format!("invalid HTTP method: {e}")))?;
384
385    let mut headers = request.headers;
386    strip_forbidden_headers(&mut headers);
387
388    let mut builder = SHARED_HTTP_CLIENT
389        .request(method, &request.url)
390        .headers(headers);
391    if !request.body.is_empty() {
392        builder = builder.body(request.body);
393    }
394
395    let response = tokio::select! {
396        _ = request.cancel.cancelled() => {
397            return Err(CopilotRequestError::message("Request cancelled by runtime"));
398        }
399        result = builder.send() => result.map_err(|e| CopilotRequestError::Upstream(e.to_string()))?,
400    };
401
402    let status = response.status().as_u16();
403    let status_text = response.status().canonical_reason().map(str::to_string);
404    let headers = response.headers().clone();
405    let body = response
406        .bytes_stream()
407        .map(|item| item.map_err(|e| CopilotRequestError::Upstream(e.to_string())));
408
409    Ok(CopilotHttpResponse {
410        status,
411        status_text,
412        headers,
413        body: Box::pin(body),
414    })
415}
416
/// Write half of the upstream WebSocket stream, as produced by splitting the
/// connected stream.
type UpstreamWrite =
    futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
419
/// Transform applied to a WebSocket message; return `None` to drop it.
///
/// The closure takes the message by value and returns the (possibly modified)
/// message to forward. It is shared behind an `Arc` and must be `Send + Sync`.
pub type WebSocketTransform =
    Arc<dyn Fn(CopilotWebSocketMessage) -> Option<CopilotWebSocketMessage> + Send + Sync>;
423
/// Builder for a [`CopilotWebSocketForwarder`].
pub struct CopilotWebSocketForwarderBuilder {
    /// Upstream WebSocket URL to dial.
    url: String,
    /// Handshake headers to offer upstream; forbidden headers are skipped when
    /// connecting.
    headers: HeaderMap,
    /// Optional hook applied to runtime→upstream messages.
    on_send_request_message: Option<WebSocketTransform>,
    /// Optional hook applied to upstream→runtime messages.
    on_send_response_message: Option<WebSocketTransform>,
}
431
432impl CopilotWebSocketForwarderBuilder {
433    /// Hook runtime→upstream messages (mutate or drop before forwarding).
434    pub fn on_send_request_message(mut self, transform: WebSocketTransform) -> Self {
435        self.on_send_request_message = Some(transform);
436        self
437    }
438
439    /// Hook upstream→runtime messages (mutate or drop before forwarding).
440    pub fn on_send_response_message(mut self, transform: WebSocketTransform) -> Self {
441        self.on_send_response_message = Some(transform);
442        self
443    }
444
445    /// Dial the upstream WebSocket and begin pumping upstream→runtime messages
446    /// into `response`.
447    pub async fn connect(
448        self,
449        response: CopilotWebSocketResponse,
450    ) -> Result<CopilotWebSocketForwarder, CopilotRequestError> {
451        let mut request =
452            self.url.as_str().into_client_request().map_err(|e| {
453                CopilotRequestError::Upstream(format!("invalid websocket url: {e}"))
454            })?;
455        for (name, value) in &self.headers {
456            if is_forbidden_header(name) {
457                continue;
458            }
459            request.headers_mut().append(name.clone(), value.clone());
460        }
461
462        let (stream, _) = connect_async(request)
463            .await
464            .map_err(|e| CopilotRequestError::Upstream(format!("websocket connect failed: {e}")))?;
465        let (write, mut read) = stream.split();
466
467        let cancel = CancellationToken::new();
468        let loop_cancel = cancel.clone();
469        let on_response = self.on_send_response_message.clone();
470        tokio::spawn(async move {
471            loop {
472                tokio::select! {
473                    _ = loop_cancel.cancelled() => break,
474                    msg = read.next() => match msg {
475                        Some(Ok(Message::Text(text))) => {
476                            let message = CopilotWebSocketMessage::from_text(text);
477                            if let Some(out) = apply_transform(&on_response, message) {
478                                let _ = response.send_message(out).await;
479                            }
480                        }
481                        Some(Ok(Message::Binary(data))) => {
482                            let message = CopilotWebSocketMessage { data, binary: true };
483                            if let Some(out) = apply_transform(&on_response, message) {
484                                let _ = response.send_message(out).await;
485                            }
486                        }
487                        Some(Ok(Message::Close(_))) | None => break,
488                        Some(Ok(_)) => continue,
489                        Some(Err(e)) => {
490                            let _ = response.fail(e.to_string(), None).await;
491                            return;
492                        }
493                    }
494                }
495            }
496            let _ = response.close().await;
497        });
498
499        Ok(CopilotWebSocketForwarder {
500            write: AsyncMutex::new(Some(write)),
501            on_send_request_message: self.on_send_request_message,
502            cancel,
503        })
504    }
505}
506
/// The default WebSocket handler: forwards each runtime message to the real
/// upstream and each upstream message back to the runtime. Mutate by supplying
/// transforms on the [builder](CopilotWebSocketForwarder::builder).
///
/// NOTE(review): in the code visible here only `close` cancels `cancel`; if the
/// forwarder is dropped without `close`, the reader task spawned by `connect`
/// appears to keep running until the upstream closes — confirm this is intended.
pub struct CopilotWebSocketForwarder {
    /// Upstream write half; becomes `None` once `close` has taken it.
    write: AsyncMutex<Option<UpstreamWrite>>,
    /// Optional hook applied to runtime→upstream messages before sending.
    on_send_request_message: Option<WebSocketTransform>,
    /// Cancelled by `close` to stop the upstream→runtime reader task.
    cancel: CancellationToken,
}
515
516impl CopilotWebSocketForwarder {
517    /// Start building a forwarding handler for `url` with the given upstream
518    /// handshake headers.
519    pub fn builder(url: String, headers: HeaderMap) -> CopilotWebSocketForwarderBuilder {
520        CopilotWebSocketForwarderBuilder {
521            url,
522            headers,
523            on_send_request_message: None,
524            on_send_response_message: None,
525        }
526    }
527}
528
529#[async_trait]
530impl CopilotWebSocketHandler for CopilotWebSocketForwarder {
531    async fn send_request_message(
532        &self,
533        message: CopilotWebSocketMessage,
534    ) -> Result<(), CopilotRequestError> {
535        let Some(message) = apply_transform(&self.on_send_request_message, message) else {
536            return Ok(());
537        };
538        let ws_message = if message.binary {
539            Message::Binary(message.data)
540        } else {
541            let text = match String::from_utf8(message.data) {
542                Ok(text) => text,
543                Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
544            };
545            Message::Text(text)
546        };
547        let mut guard = self.write.lock().await;
548        if let Some(write) = guard.as_mut() {
549            write
550                .send(ws_message)
551                .await
552                .map_err(|e| CopilotRequestError::Upstream(e.to_string()))?;
553        }
554        Ok(())
555    }
556
557    async fn close(&self) -> Result<(), CopilotRequestError> {
558        self.cancel.cancel();
559        let mut guard = self.write.lock().await;
560        if let Some(mut write) = guard.take() {
561            let _ = write.send(Message::Close(None)).await;
562            let _ = write.close().await;
563        }
564        Ok(())
565    }
566}
567
568fn apply_transform(
569    transform: &Option<WebSocketTransform>,
570    message: CopilotWebSocketMessage,
571) -> Option<CopilotWebSocketMessage> {
572    match transform {
573        Some(f) => f(message),
574        None => Some(message),
575    }
576}
577
/// Mutable response state machine for a single exchange. Both flags start out
/// `false` (the derived `Default`).
#[derive(Default)]
struct ResponseState {
    /// Set once the response head has been started.
    started: bool,
    /// Set once the response has been ended, normally or with an error.
    finished: bool,
}
584
/// Request context populated when the matching `httpRequestStart` frame
/// arrives. Held behind a `OnceLock` so the owning [`CopilotRequestExchange`]
/// can be created bare by a body chunk that races ahead of its start frame.
#[derive(Default)]
struct RequestMeta {
    /// Runtime session that triggered the request, if any.
    session_id: Option<String>,
    /// HTTP method of the intercepted request.
    method: String,
    /// Request URL.
    url: String,
    /// Request headers decoded from the wire representation.
    headers: HeaderMap,
    /// Transport the runtime would otherwise use.
    transport: CopilotRequestTransport,
}

/// One intercepted request in flight.
///
/// Carries the request metadata plus the body byte stream the runtime feeds in
/// via `httpRequestChunk` frames, and emits the handler's response straight back
/// to the runtime through the generated `llmInference` server API — a single
/// object the dispatcher owns and the handler drives.
struct CopilotRequestExchange {
    /// Runtime-minted id of this request.
    request_id: String,
    /// Request metadata, set once from the start frame.
    meta: OnceLock<RequestMeta>,
    /// Fires when the runtime cancels this request.
    cancel: CancellationToken,
    /// Weak handle to the client; upgrading fails once the client is gone.
    client: Weak<ClientInner>,
    /// Sender feeding the request body stream. Dropped (set to `None`) on `end`
    /// or `cancel` to close the stream.
    body_tx: Mutex<Option<mpsc::UnboundedSender<Vec<u8>>>>,
    /// Receiving end of the request body stream.
    body_rx: AsyncMutex<mpsc::UnboundedReceiver<Vec<u8>>>,
    /// Response state machine flags.
    state: Mutex<ResponseState>,
}
614
615impl CopilotRequestExchange {
616    fn new(request_id: String, client: Weak<ClientInner>) -> Self {
617        let (body_tx, body_rx) = mpsc::unbounded_channel();
618        Self {
619            request_id,
620            meta: OnceLock::new(),
621            cancel: CancellationToken::new(),
622            client,
623            body_tx: Mutex::new(Some(body_tx)),
624            body_rx: AsyncMutex::new(body_rx),
625            state: Mutex::new(ResponseState::default()),
626        }
627    }
628
629    /// Fill in the request context once the matching start frame arrives.
630    fn set_context(&self, params: LlmInferenceHttpRequestStartRequest) {
631        let _ = self.meta.set(RequestMeta {
632            session_id: params.session_id.map(SessionId::into_inner),
633            method: params.method,
634            url: params.url,
635            headers: headers_from_wire(&params.headers),
636            transport: CopilotRequestTransport::from_wire(params.transport),
637        });
638    }
639
640    /// Request metadata. Always populated before the handler runs; the
641    /// defaulted fallback only guards the (contract-impossible) case of a body
642    /// chunk with no preceding start frame.
643    fn meta(&self) -> &RequestMeta {
644        self.meta.get_or_init(RequestMeta::default)
645    }
646
647    fn context(&self) -> CopilotRequestContext {
648        let meta = self.meta();
649        CopilotRequestContext {
650            request_id: self.request_id.clone(),
651            session_id: meta.session_id.clone(),
652            transport: meta.transport,
653            url: meta.url.clone(),
654            headers: meta.headers.clone(),
655            cancel: self.cancel.clone(),
656        }
657    }
658
659    fn client(&self) -> Result<Client, CopilotRequestError> {
660        self.client
661            .upgrade()
662            .map(Client::from_inner)
663            .ok_or(CopilotRequestError::ConnectionClosed)
664    }
665
666    fn request_id(&self) -> RequestId {
667        RequestId::new(self.request_id.clone())
668    }
669
670    // --- Request body feed (driven by the dispatcher as frames arrive) ---
671
672    fn push_chunk(&self, data: Vec<u8>) {
673        if let Some(tx) = self.body_tx.lock().as_ref() {
674            let _ = tx.send(data);
675        }
676    }
677
678    fn push_end(&self) {
679        *self.body_tx.lock() = None;
680    }
681
682    fn push_cancel(&self) {
683        self.cancel.cancel();
684        *self.body_tx.lock() = None;
685    }
686
687    async fn recv_body(&self) -> Option<Vec<u8>> {
688        self.body_rx.lock().await.recv().await
689    }
690
691    async fn drain_body(&self) -> Vec<u8> {
692        let mut buf = Vec::new();
693        let mut rx = self.body_rx.lock().await;
694        while let Some(frame) = rx.recv().await {
695            buf.extend_from_slice(&frame);
696        }
697        buf
698    }
699
700    // --- Response emit (driven by the handler). Strict state machine: ---
701    // start_response once -> 0..N write -> exactly one of
702    // end_response / error_response.
703
704    fn started(&self) -> bool {
705        self.state.lock().started
706    }
707
708    fn finished(&self) -> bool {
709        self.state.lock().finished
710    }
711
712    async fn start_response(
713        &self,
714        status: u16,
715        status_text: Option<String>,
716        headers: HeaderMap,
717    ) -> Result<(), CopilotRequestError> {
718        {
719            let mut state = self.state.lock();
720            if state.started {
721                return Err(CopilotRequestError::InvalidState(
722                    "response start() called twice".to_string(),
723                ));
724            }
725            if state.finished {
726                return Err(CopilotRequestError::InvalidState(
727                    "response already finished".to_string(),
728                ));
729            }
730            state.started = true;
731        }
732        let request = LlmInferenceHttpResponseStartRequest {
733            headers: headers_to_wire(&headers),
734            request_id: self.request_id(),
735            status: i64::from(status),
736            status_text,
737        };
738        self.client()?
739            .rpc()
740            .llm_inference()
741            .http_response_start(request)
742            .await?;
743        Ok(())
744    }
745
746    /// Start the WebSocket upgrade head (status 101) once, ignoring repeat
747    /// calls. The dispatcher emits it eagerly before pumping; later writes call
748    /// this as a harmless no-op backstop.
749    async fn ensure_ws_started(&self) -> Result<(), CopilotRequestError> {
750        if self.started() {
751            return Ok(());
752        }
753        self.start_response(101, None, HeaderMap::new()).await
754    }
755
756    async fn write_text(&self, text: &str) -> Result<(), CopilotRequestError> {
757        self.write(text.to_string(), false).await
758    }
759
760    async fn write_binary(&self, data: &[u8]) -> Result<(), CopilotRequestError> {
761        let encoded = base64::engine::general_purpose::STANDARD.encode(data);
762        self.write(encoded, true).await
763    }
764
765    async fn write(&self, data: String, binary: bool) -> Result<(), CopilotRequestError> {
766        {
767            let state = self.state.lock();
768            if !state.started {
769                return Err(CopilotRequestError::InvalidState(
770                    "response write called before start()".to_string(),
771                ));
772            }
773            if state.finished {
774                return Err(CopilotRequestError::InvalidState(
775                    "response write called after end()/error()".to_string(),
776                ));
777            }
778        }
779        let request = LlmInferenceHttpResponseChunkRequest {
780            binary: binary.then_some(true),
781            data,
782            end: Some(false),
783            error: None,
784            request_id: self.request_id(),
785        };
786        self.client()?
787            .rpc()
788            .llm_inference()
789            .http_response_chunk(request)
790            .await?;
791        Ok(())
792    }
793
794    async fn end_response(&self) -> Result<(), CopilotRequestError> {
795        {
796            let mut state = self.state.lock();
797            if state.finished {
798                return Ok(());
799            }
800            state.finished = true;
801        }
802        let request = LlmInferenceHttpResponseChunkRequest {
803            binary: None,
804            data: String::new(),
805            end: Some(true),
806            error: None,
807            request_id: self.request_id(),
808        };
809        self.client()?
810            .rpc()
811            .llm_inference()
812            .http_response_chunk(request)
813            .await?;
814        Ok(())
815    }
816
817    async fn error_response(
818        &self,
819        message: impl Into<String>,
820        code: Option<String>,
821    ) -> Result<(), CopilotRequestError> {
822        {
823            let mut state = self.state.lock();
824            if state.finished {
825                return Ok(());
826            }
827            state.finished = true;
828        }
829        let request = LlmInferenceHttpResponseChunkRequest {
830            binary: None,
831            data: String::new(),
832            end: Some(true),
833            error: Some(LlmInferenceHttpResponseChunkError {
834                code,
835                message: message.into(),
836            }),
837            request_id: self.request_id(),
838        };
839        self.client()?
840            .rpc()
841            .llm_inference()
842            .http_response_chunk(request)
843            .await?;
844        Ok(())
845    }
846}
847
848/// Drive one exchange through the registered handler, dispatching by transport.
849async fn drive_exchange(
850    exchange: &Arc<CopilotRequestExchange>,
851    handler: &Arc<dyn CopilotRequestHandler>,
852) -> Result<(), CopilotRequestError> {
853    let ctx = exchange.context();
854    let meta = exchange.meta();
855    match meta.transport {
856        CopilotRequestTransport::Http => {
857            let body = exchange.drain_body().await;
858            let request = CopilotHttpRequest {
859                method: meta.method.clone(),
860                url: meta.url.clone(),
861                headers: meta.headers.clone(),
862                body,
863                cancel: ctx.cancel.clone(),
864            };
865            let response = handler.send_request(request, &ctx).await?;
866            stream_http_response(response, exchange, &ctx.cancel).await
867        }
868        CopilotRequestTransport::WebSocket => {
869            // The runtime blocks the WebSocket connect until it receives the 101
870            // response head (the upgrade acknowledgement) and only then forwards
871            // inbound messages as request-body chunks. Emit it eagerly here —
872            // waiting for the first upstream message would deadlock, since the
873            // upstream stays silent until it receives a request message the
874            // runtime won't send before the upgrade completes.
875            exchange.ensure_ws_started().await?;
876            let response = CopilotWebSocketResponse::new(exchange.clone());
877            let ws = handler.open_websocket(&ctx, response).await?;
878            let result = pump_websocket_requests(ws.as_ref(), exchange, &ctx.cancel).await;
879            let _ = ws.close().await;
880            match result {
881                Ok(()) => exchange.end_response().await,
882                Err(err) if ctx.cancel.is_cancelled() => {
883                    exchange
884                        .error_response(
885                            "Request cancelled by runtime",
886                            Some("cancelled".to_string()),
887                        )
888                        .await?;
889                    let _ = err;
890                    Ok(())
891                }
892                Err(err) => Err(err),
893            }
894        }
895    }
896}
897
898/// Stream an HTTP response into the runtime, honouring cancellation.
899async fn stream_http_response(
900    response: CopilotHttpResponse,
901    exchange: &CopilotRequestExchange,
902    cancel: &CancellationToken,
903) -> Result<(), CopilotRequestError> {
904    exchange
905        .start_response(response.status, response.status_text, response.headers)
906        .await?;
907
908    let mut body = response.body;
909    loop {
910        tokio::select! {
911            _ = cancel.cancelled() => {
912                return exchange
913                    .error_response("Request cancelled by runtime", Some("cancelled".to_string()))
914                    .await;
915            }
916            next = body.next() => match next {
917                Some(Ok(chunk)) => {
918                    for piece in chunk.chunks(32 * 1024) {
919                        exchange.write_binary(piece).await?;
920                    }
921                }
922                Some(Err(e)) => {
923                    return exchange.error_response(e.to_string(), None).await;
924                }
925                None => break,
926            }
927        }
928    }
929    exchange.end_response().await
930}
931
932/// Forward runtime→upstream WebSocket messages until the runtime closes its side
933/// or cancels.
934async fn pump_websocket_requests(
935    handler: &dyn CopilotWebSocketHandler,
936    exchange: &CopilotRequestExchange,
937    cancel: &CancellationToken,
938) -> Result<(), CopilotRequestError> {
939    loop {
940        tokio::select! {
941            _ = cancel.cancelled() => {
942                return Err(CopilotRequestError::message("Request cancelled by runtime"));
943            }
944            frame = exchange.recv_body() => match frame {
945                Some(data) => {
946                    handler
947                        .send_request_message(CopilotWebSocketMessage { data, binary: false })
948                        .await?;
949                }
950                None => return Ok(()),
951            }
952        }
953    }
954}
955
956/// Drive the exchange's response to a terminal state once the handler returns,
957/// covering handlers that error, get cancelled, or forget to finalize.
958async fn finalize_exchange(
959    exchange: &CopilotRequestExchange,
960    result: Result<(), CopilotRequestError>,
961) {
962    match result {
963        Ok(()) => {
964            if !exchange.finished() {
965                fail_via_response(
966                    exchange,
967                    502,
968                    "Copilot request handler returned without finalising the response".to_string(),
969                )
970                .await;
971            }
972        }
973        Err(err) => {
974            if exchange.finished() {
975                return;
976            }
977            if exchange.cancel.is_cancelled() {
978                if !exchange.started() {
979                    let _ = exchange.start_response(499, None, HeaderMap::new()).await;
980                }
981                let _ = exchange
982                    .error_response(
983                        "Request cancelled by runtime",
984                        Some("cancelled".to_string()),
985                    )
986                    .await;
987            } else {
988                fail_via_response(exchange, 502, err.to_string()).await;
989            }
990        }
991    }
992}
993
994async fn fail_via_response(exchange: &CopilotRequestExchange, status: u16, message: String) {
995    if !exchange.started() {
996        let _ = exchange
997            .start_response(status, None, HeaderMap::new())
998            .await;
999    }
1000    let _ = exchange.error_response(message, None).await;
1001}
1002
/// Routes inbound `llmInference.*` requests to the registered handler,
/// reassembling each request's streaming body and acking every frame.
pub(crate) struct CopilotRequestDispatcher {
    /// Handler that every started exchange is driven through.
    handler: Arc<dyn CopilotRequestHandler>,
    /// Weak handle to the client used to send JSON-RPC responses. Empty until
    /// `set_client` is called; later calls do not replace the first value.
    client: OnceLock<Weak<ClientInner>>,
    /// Exchanges currently tracked, keyed by request id. Entries are created on
    /// the first frame seen for an id and removed when the exchange's task ends.
    pending: Mutex<HashMap<String, Arc<CopilotRequestExchange>>>,
}
1010
1011impl CopilotRequestDispatcher {
1012    pub(crate) fn new(handler: Arc<dyn CopilotRequestHandler>) -> Self {
1013        Self {
1014            handler,
1015            client: OnceLock::new(),
1016            pending: Mutex::new(HashMap::new()),
1017        }
1018    }
1019
1020    pub(crate) fn set_client(&self, client: Weak<ClientInner>) {
1021        let _ = self.client.set(client);
1022    }
1023
1024    fn client(&self) -> Option<Client> {
1025        self.client
1026            .get()
1027            .and_then(Weak::upgrade)
1028            .map(Client::from_inner)
1029    }
1030
1031    fn client_weak(&self) -> Weak<ClientInner> {
1032        self.client.get().cloned().unwrap_or_else(Weak::new)
1033    }
1034
1035    pub(crate) async fn dispatch(self: &Arc<Self>, request: JsonRpcRequest) {
1036        match request.method.as_str() {
1037            METHOD_HTTP_REQUEST_START => self.handle_start(request).await,
1038            METHOD_HTTP_REQUEST_CHUNK => self.handle_chunk(request).await,
1039            other => {
1040                warn!(method = other, "unknown llmInference request method");
1041                self.send_error(request.id, "unknown llmInference method")
1042                    .await;
1043            }
1044        }
1045    }
1046
1047    fn get_or_create_exchange(&self, request_id: String) -> Arc<CopilotRequestExchange> {
1048        // The runtime dispatches httpRequestStart and httpRequestChunk frames
1049        // independently. get-or-create keeps the adapter correct regardless of
1050        // arrival order: a body chunk (including the terminal end frame) that
1051        // races ahead of its start frame is buffered into the same exchange
1052        // rather than dropped, which would otherwise hang the body drain.
1053        self.pending
1054            .lock()
1055            .entry(request_id.clone())
1056            .or_insert_with(|| {
1057                Arc::new(CopilotRequestExchange::new(request_id, self.client_weak()))
1058            })
1059            .clone()
1060    }
1061
1062    async fn handle_start(self: &Arc<Self>, request: JsonRpcRequest) {
1063        let id = request.id;
1064        let Some(params) = parse_params::<LlmInferenceHttpRequestStartRequest>(&request) else {
1065            self.send_error(id, "invalid llmInference.httpRequestStart params")
1066                .await;
1067            return;
1068        };
1069
1070        // Adopt any exchange a racing chunk already created — with its buffered
1071        // body — rather than dropping those frames.
1072        let request_id = params.request_id.clone().into_inner();
1073        let exchange = self.get_or_create_exchange(request_id.clone());
1074        exchange.set_context(params);
1075
1076        let handler = self.handler.clone();
1077        let dispatcher = Arc::clone(self);
1078        let exchange_for_task = exchange.clone();
1079        tokio::spawn(async move {
1080            let result = drive_exchange(&exchange_for_task, &handler).await;
1081            finalize_exchange(&exchange_for_task, result).await;
1082            dispatcher.remove_pending(&request_id);
1083        });
1084
1085        self.ack(id).await;
1086    }
1087
1088    async fn handle_chunk(&self, request: JsonRpcRequest) {
1089        let id = request.id;
1090        let Some(params) = parse_params::<LlmInferenceHttpRequestChunkRequest>(&request) else {
1091            self.send_error(id, "invalid llmInference.httpRequestChunk params")
1092                .await;
1093            return;
1094        };
1095
1096        // May arrive before the matching start frame; get-or-create so the body
1097        // is buffered, never lost.
1098        let exchange = self.get_or_create_exchange(params.request_id.to_string());
1099        apply_chunk(&exchange, &params);
1100
1101        self.ack(id).await;
1102    }
1103
1104    fn remove_pending(&self, request_id: &str) {
1105        self.pending.lock().remove(request_id);
1106    }
1107
1108    async fn ack(&self, id: u64) {
1109        let Some(client) = self.client() else {
1110            return;
1111        };
1112        let _ = client
1113            .send_response(&JsonRpcResponse {
1114                jsonrpc: "2.0".to_string(),
1115                id,
1116                result: Some(serde_json::json!({})),
1117                error: None,
1118            })
1119            .await;
1120    }
1121
1122    async fn send_error(&self, id: u64, message: &str) {
1123        let Some(client) = self.client() else {
1124            return;
1125        };
1126        let _ = client
1127            .send_response(&JsonRpcResponse {
1128                jsonrpc: "2.0".to_string(),
1129                id,
1130                result: None,
1131                error: Some(crate::JsonRpcError {
1132                    code: error_codes::INTERNAL_ERROR,
1133                    message: message.to_string(),
1134                    data: None,
1135                }),
1136            })
1137            .await;
1138    }
1139}
1140
1141/// Apply one body chunk to a pending request: route data into the body stream,
1142/// or terminate it on `end` / `cancel`.
1143fn apply_chunk(exchange: &CopilotRequestExchange, params: &LlmInferenceHttpRequestChunkRequest) {
1144    if params.cancel == Some(true) {
1145        exchange.push_cancel();
1146        return;
1147    }
1148
1149    if !params.data.is_empty() {
1150        let decoded = if params.binary == Some(true) {
1151            match base64::engine::general_purpose::STANDARD.decode(params.data.as_bytes()) {
1152                Ok(bytes) => bytes,
1153                Err(e) => {
1154                    warn!(error = %e, "failed to decode base64 llmInference body chunk");
1155                    return;
1156                }
1157            }
1158        } else {
1159            params.data.clone().into_bytes()
1160        };
1161        exchange.push_chunk(decoded);
1162    }
1163
1164    if params.end == Some(true) {
1165        exchange.push_end();
1166    }
1167}
1168
1169fn parse_params<T: serde::de::DeserializeOwned>(request: &JsonRpcRequest) -> Option<T> {
1170    request
1171        .params
1172        .as_ref()
1173        .and_then(|p| serde_json::from_value(p.clone()).ok())
1174}
1175
1176/// Convert a wire header map into an [`http::HeaderMap`], skipping any entry the
1177/// `http` crate rejects.
1178fn headers_from_wire(wire: &HashMap<String, Vec<String>>) -> HeaderMap {
1179    let mut headers = HeaderMap::new();
1180    for (name, values) in wire {
1181        let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else {
1182            continue;
1183        };
1184        for value in values {
1185            let Ok(header_value) = HeaderValue::from_str(value) else {
1186                continue;
1187            };
1188            headers.append(header_name.clone(), header_value);
1189        }
1190    }
1191    headers
1192}
1193
1194/// Convert an [`http::HeaderMap`] into the wire header map, dropping values that
1195/// are not valid UTF-8.
1196fn headers_to_wire(headers: &HeaderMap) -> HashMap<String, Vec<String>> {
1197    let mut wire: HashMap<String, Vec<String>> = HashMap::new();
1198    for (name, value) in headers {
1199        let Ok(value) = value.to_str() else {
1200            continue;
1201        };
1202        wire.entry(name.as_str().to_string())
1203            .or_default()
1204            .push(value.to_string());
1205    }
1206    wire
1207}