1use std::collections::HashMap;
25use std::pin::Pin;
26use std::sync::{Arc, LazyLock, OnceLock, Weak};
27
28use async_trait::async_trait;
29use base64::Engine;
30use bytes::Bytes;
31use futures_util::{SinkExt, Stream, StreamExt};
32use http::HeaderMap;
33use http::header::{HeaderName, HeaderValue};
34use parking_lot::Mutex;
35use tokio::net::TcpStream;
36use tokio::sync::{Mutex as AsyncMutex, mpsc};
37use tokio_tungstenite::tungstenite::Message;
38use tokio_tungstenite::tungstenite::client::IntoClientRequest;
39use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
40use tokio_util::sync::CancellationToken;
41use tracing::warn;
42
43use crate::generated::api_types::{
44 LlmInferenceHttpRequestChunkRequest, LlmInferenceHttpRequestStartRequest,
45 LlmInferenceHttpRequestStartTransport, LlmInferenceHttpResponseChunkError,
46 LlmInferenceHttpResponseChunkRequest, LlmInferenceHttpResponseStartRequest,
47};
48use crate::{
49 Client, ClientInner, JsonRpcRequest, JsonRpcResponse, RequestId, SessionId, error_codes,
50};
51
/// JSON-RPC method that opens a proxied request (routed to `handle_start`).
const METHOD_HTTP_REQUEST_START: &str = "llmInference.httpRequestStart";
/// JSON-RPC method carrying request-body data, end-of-body, or cancellation
/// for an existing proxied request (routed to `handle_chunk`).
const METHOD_HTTP_REQUEST_CHUNK: &str = "llmInference.httpRequestChunk";
54
/// Transport requested by the runtime for a proxied request.
///
/// `Http` is the default and the fallback whenever the wire value does not
/// explicitly ask for a WebSocket.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CopilotRequestTransport {
    /// Plain request/response; the response body is streamed back in chunks.
    #[default]
    Http,
    /// Bidirectional WebSocket session relayed message by message.
    WebSocket,
}
65
66impl CopilotRequestTransport {
67 fn from_wire(value: Option<LlmInferenceHttpRequestStartTransport>) -> Self {
68 match value {
69 Some(LlmInferenceHttpRequestStartTransport::Websocket) => Self::WebSocket,
70 _ => Self::Http,
71 }
72 }
73}
74
75#[derive(Debug)]
77#[non_exhaustive]
78pub enum CopilotRequestError {
79 ConnectionClosed,
81
82 InvalidState(String),
85
86 Upstream(String),
88
89 Handler(String),
91
92 Rpc(crate::Error),
94}
95
96impl CopilotRequestError {
97 pub fn message(message: impl Into<String>) -> Self {
100 Self::Handler(message.into())
101 }
102}
103
104impl std::fmt::Display for CopilotRequestError {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 match self {
107 Self::ConnectionClosed => {
108 f.write_str("Copilot request response used after RPC connection closed")
109 }
110 Self::InvalidState(message) | Self::Upstream(message) | Self::Handler(message) => {
111 f.write_str(message)
112 }
113 Self::Rpc(err) => write!(f, "{err}"),
114 }
115 }
116}
117
118impl std::error::Error for CopilotRequestError {
119 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
120 match self {
121 Self::Rpc(err) => Some(err),
122 _ => None,
123 }
124 }
125}
126
127impl From<crate::Error> for CopilotRequestError {
128 fn from(err: crate::Error) -> Self {
129 Self::Rpc(err)
130 }
131}
132
133#[derive(Clone)]
136#[non_exhaustive]
137pub struct CopilotRequestContext {
138 pub request_id: String,
140 pub session_id: Option<String>,
143 pub transport: CopilotRequestTransport,
145 pub url: String,
147 pub headers: HeaderMap,
149 pub cancel: CancellationToken,
151}
152
153pub type CopilotHttpResponseBody =
155 Pin<Box<dyn Stream<Item = Result<Bytes, CopilotRequestError>> + Send>>;
156
157#[non_exhaustive]
159pub struct CopilotHttpRequest {
160 pub method: String,
162 pub url: String,
164 pub headers: HeaderMap,
166 pub body: Vec<u8>,
168 pub cancel: CancellationToken,
170}
171
172#[non_exhaustive]
174pub struct CopilotHttpResponse {
175 pub status: u16,
177 pub status_text: Option<String>,
179 pub headers: HeaderMap,
181 pub body: CopilotHttpResponseBody,
183}
184
185impl CopilotHttpResponse {
186 pub fn new(
188 status: u16,
189 status_text: Option<String>,
190 headers: HeaderMap,
191 body: CopilotHttpResponseBody,
192 ) -> Self {
193 Self {
194 status,
195 status_text,
196 headers,
197 body,
198 }
199 }
200}
201
/// A single WebSocket message relayed between the runtime and the upstream peer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CopilotWebSocketMessage {
    /// Message payload. Text payloads are expected, but not guaranteed, to be
    /// UTF-8; the send paths convert invalid sequences lossily.
    pub data: Vec<u8>,
    /// `true` for a binary frame, `false` for a text frame.
    pub binary: bool,
}

impl CopilotWebSocketMessage {
    /// Creates a text message from any string-like value.
    pub fn from_text(data: impl Into<String>) -> Self {
        Self {
            data: data.into().into_bytes(),
            binary: false,
        }
    }

    /// Creates a binary message from any byte buffer.
    pub fn from_binary(data: impl Into<Vec<u8>>) -> Self {
        Self {
            data: data.into(),
            binary: true,
        }
    }
}
221
222#[derive(Clone)]
225pub struct CopilotWebSocketResponse {
226 exchange: Arc<CopilotRequestExchange>,
227}
228
229impl CopilotWebSocketResponse {
230 fn new(exchange: Arc<CopilotRequestExchange>) -> Self {
231 Self { exchange }
232 }
233
234 pub async fn send_message(
236 &self,
237 message: CopilotWebSocketMessage,
238 ) -> Result<(), CopilotRequestError> {
239 self.exchange.ensure_ws_started().await?;
240 if message.binary {
241 self.exchange.write_binary(&message.data).await
242 } else {
243 let text = String::from_utf8_lossy(&message.data);
244 self.exchange.write_text(&text).await
245 }
246 }
247
248 pub async fn close(&self) -> Result<(), CopilotRequestError> {
250 self.exchange.end_response().await
251 }
252
253 async fn fail(
254 &self,
255 message: impl Into<String>,
256 code: Option<String>,
257 ) -> Result<(), CopilotRequestError> {
258 self.exchange.error_response(message, code).await
259 }
260}
261
262#[async_trait]
266pub trait CopilotWebSocketHandler: Send + Sync {
267 async fn send_request_message(
269 &self,
270 message: CopilotWebSocketMessage,
271 ) -> Result<(), CopilotRequestError>;
272
273 async fn close(&self) -> Result<(), CopilotRequestError>;
275}
276
277#[async_trait]
283pub trait CopilotRequestHandler: Send + Sync + 'static {
284 async fn send_request(
288 &self,
289 request: CopilotHttpRequest,
290 _ctx: &CopilotRequestContext,
291 ) -> Result<CopilotHttpResponse, CopilotRequestError> {
292 forward_http(request).await
293 }
294
295 async fn open_websocket(
307 &self,
308 ctx: &CopilotRequestContext,
309 response: CopilotWebSocketResponse,
310 ) -> Result<Box<dyn CopilotWebSocketHandler>, CopilotRequestError> {
311 let handler = CopilotWebSocketForwarder::builder(ctx.url.clone(), ctx.headers.clone())
312 .connect(response)
313 .await?;
314 Ok(Box::new(handler))
315 }
316}
317
318#[async_trait]
321impl<H: CopilotRequestHandler> CopilotRequestHandler for Arc<H> {
322 async fn send_request(
323 &self,
324 request: CopilotHttpRequest,
325 ctx: &CopilotRequestContext,
326 ) -> Result<CopilotHttpResponse, CopilotRequestError> {
327 (**self).send_request(request, ctx).await
328 }
329
330 async fn open_websocket(
331 &self,
332 ctx: &CopilotRequestContext,
333 response: CopilotWebSocketResponse,
334 ) -> Result<Box<dyn CopilotWebSocketHandler>, CopilotRequestError> {
335 (**self).open_websocket(ctx, response).await
336 }
337}
/// Connection-level (hop-by-hop) headers that are never copied onto a
/// forwarded request. Entries are lowercase because header names are compared
/// in their lowercase `HeaderName::as_str` form.
const FORBIDDEN_HEADERS: &[&str] = &[
    "host",
    "connection",
    "content-length",
    "transfer-encoding",
    "keep-alive",
    "upgrade",
    "proxy-connection",
    "te",
    "trailer",
];
350
351fn is_forbidden_header(name: &HeaderName) -> bool {
352 let name = name.as_str();
353 FORBIDDEN_HEADERS.contains(&name) || name.starts_with("sec-websocket")
354}
355
356fn strip_forbidden_headers(headers: &mut HeaderMap) {
358 let forbidden: Vec<HeaderName> = headers
359 .keys()
360 .filter(|name| is_forbidden_header(name))
361 .cloned()
362 .collect();
363 for name in forbidden {
364 headers.remove(&name);
365 }
366}
367
368static SHARED_HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
369 reqwest::Client::builder()
370 .redirect(reqwest::redirect::Policy::none())
371 .build()
372 .expect("default reqwest client must build")
373});
374
375pub async fn forward_http(
380 request: CopilotHttpRequest,
381) -> Result<CopilotHttpResponse, CopilotRequestError> {
382 let method = reqwest::Method::from_bytes(request.method.as_bytes())
383 .map_err(|e| CopilotRequestError::InvalidState(format!("invalid HTTP method: {e}")))?;
384
385 let mut headers = request.headers;
386 strip_forbidden_headers(&mut headers);
387
388 let mut builder = SHARED_HTTP_CLIENT
389 .request(method, &request.url)
390 .headers(headers);
391 if !request.body.is_empty() {
392 builder = builder.body(request.body);
393 }
394
395 let response = tokio::select! {
396 _ = request.cancel.cancelled() => {
397 return Err(CopilotRequestError::message("Request cancelled by runtime"));
398 }
399 result = builder.send() => result.map_err(|e| CopilotRequestError::Upstream(e.to_string()))?,
400 };
401
402 let status = response.status().as_u16();
403 let status_text = response.status().canonical_reason().map(str::to_string);
404 let headers = response.headers().clone();
405 let body = response
406 .bytes_stream()
407 .map(|item| item.map_err(|e| CopilotRequestError::Upstream(e.to_string())));
408
409 Ok(CopilotHttpResponse {
410 status,
411 status_text,
412 headers,
413 body: Box::pin(body),
414 })
415}
416
417type UpstreamWrite =
418 futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
419
420pub type WebSocketTransform =
422 Arc<dyn Fn(CopilotWebSocketMessage) -> Option<CopilotWebSocketMessage> + Send + Sync>;
423
424pub struct CopilotWebSocketForwarderBuilder {
426 url: String,
427 headers: HeaderMap,
428 on_send_request_message: Option<WebSocketTransform>,
429 on_send_response_message: Option<WebSocketTransform>,
430}
431
432impl CopilotWebSocketForwarderBuilder {
433 pub fn on_send_request_message(mut self, transform: WebSocketTransform) -> Self {
435 self.on_send_request_message = Some(transform);
436 self
437 }
438
439 pub fn on_send_response_message(mut self, transform: WebSocketTransform) -> Self {
441 self.on_send_response_message = Some(transform);
442 self
443 }
444
445 pub async fn connect(
448 self,
449 response: CopilotWebSocketResponse,
450 ) -> Result<CopilotWebSocketForwarder, CopilotRequestError> {
451 let mut request =
452 self.url.as_str().into_client_request().map_err(|e| {
453 CopilotRequestError::Upstream(format!("invalid websocket url: {e}"))
454 })?;
455 for (name, value) in &self.headers {
456 if is_forbidden_header(name) {
457 continue;
458 }
459 request.headers_mut().append(name.clone(), value.clone());
460 }
461
462 let (stream, _) = connect_async(request)
463 .await
464 .map_err(|e| CopilotRequestError::Upstream(format!("websocket connect failed: {e}")))?;
465 let (write, mut read) = stream.split();
466
467 let cancel = CancellationToken::new();
468 let loop_cancel = cancel.clone();
469 let on_response = self.on_send_response_message.clone();
470 tokio::spawn(async move {
471 loop {
472 tokio::select! {
473 _ = loop_cancel.cancelled() => break,
474 msg = read.next() => match msg {
475 Some(Ok(Message::Text(text))) => {
476 let message = CopilotWebSocketMessage::from_text(text);
477 if let Some(out) = apply_transform(&on_response, message) {
478 let _ = response.send_message(out).await;
479 }
480 }
481 Some(Ok(Message::Binary(data))) => {
482 let message = CopilotWebSocketMessage { data, binary: true };
483 if let Some(out) = apply_transform(&on_response, message) {
484 let _ = response.send_message(out).await;
485 }
486 }
487 Some(Ok(Message::Close(_))) | None => break,
488 Some(Ok(_)) => continue,
489 Some(Err(e)) => {
490 let _ = response.fail(e.to_string(), None).await;
491 return;
492 }
493 }
494 }
495 }
496 let _ = response.close().await;
497 });
498
499 Ok(CopilotWebSocketForwarder {
500 write: AsyncMutex::new(Some(write)),
501 on_send_request_message: self.on_send_request_message,
502 cancel,
503 })
504 }
505}
506
507pub struct CopilotWebSocketForwarder {
511 write: AsyncMutex<Option<UpstreamWrite>>,
512 on_send_request_message: Option<WebSocketTransform>,
513 cancel: CancellationToken,
514}
515
516impl CopilotWebSocketForwarder {
517 pub fn builder(url: String, headers: HeaderMap) -> CopilotWebSocketForwarderBuilder {
520 CopilotWebSocketForwarderBuilder {
521 url,
522 headers,
523 on_send_request_message: None,
524 on_send_response_message: None,
525 }
526 }
527}
528
529#[async_trait]
530impl CopilotWebSocketHandler for CopilotWebSocketForwarder {
531 async fn send_request_message(
532 &self,
533 message: CopilotWebSocketMessage,
534 ) -> Result<(), CopilotRequestError> {
535 let Some(message) = apply_transform(&self.on_send_request_message, message) else {
536 return Ok(());
537 };
538 let ws_message = if message.binary {
539 Message::Binary(message.data)
540 } else {
541 let text = match String::from_utf8(message.data) {
542 Ok(text) => text,
543 Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
544 };
545 Message::Text(text)
546 };
547 let mut guard = self.write.lock().await;
548 if let Some(write) = guard.as_mut() {
549 write
550 .send(ws_message)
551 .await
552 .map_err(|e| CopilotRequestError::Upstream(e.to_string()))?;
553 }
554 Ok(())
555 }
556
557 async fn close(&self) -> Result<(), CopilotRequestError> {
558 self.cancel.cancel();
559 let mut guard = self.write.lock().await;
560 if let Some(mut write) = guard.take() {
561 let _ = write.send(Message::Close(None)).await;
562 let _ = write.close().await;
563 }
564 Ok(())
565 }
566}
567
568fn apply_transform(
569 transform: &Option<WebSocketTransform>,
570 message: CopilotWebSocketMessage,
571) -> Option<CopilotWebSocketMessage> {
572 match transform {
573 Some(f) => f(message),
574 None => Some(message),
575 }
576}
577
/// Lifecycle flags for the response half of an exchange.
#[derive(Default)]
struct ResponseState {
    /// Set once the response start (status and headers) has been claimed.
    started: bool,
    /// Set once an end or error chunk has been claimed; later writes are rejected.
    finished: bool,
}
584
585#[derive(Default)]
595struct RequestMeta {
596 session_id: Option<String>,
597 method: String,
598 url: String,
599 headers: HeaderMap,
600 transport: CopilotRequestTransport,
601}
602
603struct CopilotRequestExchange {
604 request_id: String,
605 meta: OnceLock<RequestMeta>,
606 cancel: CancellationToken,
607 client: Weak<ClientInner>,
608 body_tx: Mutex<Option<mpsc::UnboundedSender<Vec<u8>>>>,
611 body_rx: AsyncMutex<mpsc::UnboundedReceiver<Vec<u8>>>,
612 state: Mutex<ResponseState>,
613}
614
615impl CopilotRequestExchange {
616 fn new(request_id: String, client: Weak<ClientInner>) -> Self {
617 let (body_tx, body_rx) = mpsc::unbounded_channel();
618 Self {
619 request_id,
620 meta: OnceLock::new(),
621 cancel: CancellationToken::new(),
622 client,
623 body_tx: Mutex::new(Some(body_tx)),
624 body_rx: AsyncMutex::new(body_rx),
625 state: Mutex::new(ResponseState::default()),
626 }
627 }
628
629 fn set_context(&self, params: LlmInferenceHttpRequestStartRequest) {
631 let _ = self.meta.set(RequestMeta {
632 session_id: params.session_id.map(SessionId::into_inner),
633 method: params.method,
634 url: params.url,
635 headers: headers_from_wire(¶ms.headers),
636 transport: CopilotRequestTransport::from_wire(params.transport),
637 });
638 }
639
640 fn meta(&self) -> &RequestMeta {
644 self.meta.get_or_init(RequestMeta::default)
645 }
646
647 fn context(&self) -> CopilotRequestContext {
648 let meta = self.meta();
649 CopilotRequestContext {
650 request_id: self.request_id.clone(),
651 session_id: meta.session_id.clone(),
652 transport: meta.transport,
653 url: meta.url.clone(),
654 headers: meta.headers.clone(),
655 cancel: self.cancel.clone(),
656 }
657 }
658
659 fn client(&self) -> Result<Client, CopilotRequestError> {
660 self.client
661 .upgrade()
662 .map(Client::from_inner)
663 .ok_or(CopilotRequestError::ConnectionClosed)
664 }
665
666 fn request_id(&self) -> RequestId {
667 RequestId::new(self.request_id.clone())
668 }
669
670 fn push_chunk(&self, data: Vec<u8>) {
673 if let Some(tx) = self.body_tx.lock().as_ref() {
674 let _ = tx.send(data);
675 }
676 }
677
678 fn push_end(&self) {
679 *self.body_tx.lock() = None;
680 }
681
682 fn push_cancel(&self) {
683 self.cancel.cancel();
684 *self.body_tx.lock() = None;
685 }
686
687 async fn recv_body(&self) -> Option<Vec<u8>> {
688 self.body_rx.lock().await.recv().await
689 }
690
691 async fn drain_body(&self) -> Vec<u8> {
692 let mut buf = Vec::new();
693 let mut rx = self.body_rx.lock().await;
694 while let Some(frame) = rx.recv().await {
695 buf.extend_from_slice(&frame);
696 }
697 buf
698 }
699
700 fn started(&self) -> bool {
705 self.state.lock().started
706 }
707
708 fn finished(&self) -> bool {
709 self.state.lock().finished
710 }
711
712 async fn start_response(
713 &self,
714 status: u16,
715 status_text: Option<String>,
716 headers: HeaderMap,
717 ) -> Result<(), CopilotRequestError> {
718 {
719 let mut state = self.state.lock();
720 if state.started {
721 return Err(CopilotRequestError::InvalidState(
722 "response start() called twice".to_string(),
723 ));
724 }
725 if state.finished {
726 return Err(CopilotRequestError::InvalidState(
727 "response already finished".to_string(),
728 ));
729 }
730 state.started = true;
731 }
732 let request = LlmInferenceHttpResponseStartRequest {
733 headers: headers_to_wire(&headers),
734 request_id: self.request_id(),
735 status: i64::from(status),
736 status_text,
737 };
738 self.client()?
739 .rpc()
740 .llm_inference()
741 .http_response_start(request)
742 .await?;
743 Ok(())
744 }
745
746 async fn ensure_ws_started(&self) -> Result<(), CopilotRequestError> {
750 if self.started() {
751 return Ok(());
752 }
753 self.start_response(101, None, HeaderMap::new()).await
754 }
755
756 async fn write_text(&self, text: &str) -> Result<(), CopilotRequestError> {
757 self.write(text.to_string(), false).await
758 }
759
760 async fn write_binary(&self, data: &[u8]) -> Result<(), CopilotRequestError> {
761 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
762 self.write(encoded, true).await
763 }
764
765 async fn write(&self, data: String, binary: bool) -> Result<(), CopilotRequestError> {
766 {
767 let state = self.state.lock();
768 if !state.started {
769 return Err(CopilotRequestError::InvalidState(
770 "response write called before start()".to_string(),
771 ));
772 }
773 if state.finished {
774 return Err(CopilotRequestError::InvalidState(
775 "response write called after end()/error()".to_string(),
776 ));
777 }
778 }
779 let request = LlmInferenceHttpResponseChunkRequest {
780 binary: binary.then_some(true),
781 data,
782 end: Some(false),
783 error: None,
784 request_id: self.request_id(),
785 };
786 self.client()?
787 .rpc()
788 .llm_inference()
789 .http_response_chunk(request)
790 .await?;
791 Ok(())
792 }
793
794 async fn end_response(&self) -> Result<(), CopilotRequestError> {
795 {
796 let mut state = self.state.lock();
797 if state.finished {
798 return Ok(());
799 }
800 state.finished = true;
801 }
802 let request = LlmInferenceHttpResponseChunkRequest {
803 binary: None,
804 data: String::new(),
805 end: Some(true),
806 error: None,
807 request_id: self.request_id(),
808 };
809 self.client()?
810 .rpc()
811 .llm_inference()
812 .http_response_chunk(request)
813 .await?;
814 Ok(())
815 }
816
817 async fn error_response(
818 &self,
819 message: impl Into<String>,
820 code: Option<String>,
821 ) -> Result<(), CopilotRequestError> {
822 {
823 let mut state = self.state.lock();
824 if state.finished {
825 return Ok(());
826 }
827 state.finished = true;
828 }
829 let request = LlmInferenceHttpResponseChunkRequest {
830 binary: None,
831 data: String::new(),
832 end: Some(true),
833 error: Some(LlmInferenceHttpResponseChunkError {
834 code,
835 message: message.into(),
836 }),
837 request_id: self.request_id(),
838 };
839 self.client()?
840 .rpc()
841 .llm_inference()
842 .http_response_chunk(request)
843 .await?;
844 Ok(())
845 }
846}
847
848async fn drive_exchange(
850 exchange: &Arc<CopilotRequestExchange>,
851 handler: &Arc<dyn CopilotRequestHandler>,
852) -> Result<(), CopilotRequestError> {
853 let ctx = exchange.context();
854 let meta = exchange.meta();
855 match meta.transport {
856 CopilotRequestTransport::Http => {
857 let body = exchange.drain_body().await;
858 let request = CopilotHttpRequest {
859 method: meta.method.clone(),
860 url: meta.url.clone(),
861 headers: meta.headers.clone(),
862 body,
863 cancel: ctx.cancel.clone(),
864 };
865 let response = handler.send_request(request, &ctx).await?;
866 stream_http_response(response, exchange, &ctx.cancel).await
867 }
868 CopilotRequestTransport::WebSocket => {
869 exchange.ensure_ws_started().await?;
876 let response = CopilotWebSocketResponse::new(exchange.clone());
877 let ws = handler.open_websocket(&ctx, response).await?;
878 let result = pump_websocket_requests(ws.as_ref(), exchange, &ctx.cancel).await;
879 let _ = ws.close().await;
880 match result {
881 Ok(()) => exchange.end_response().await,
882 Err(err) if ctx.cancel.is_cancelled() => {
883 exchange
884 .error_response(
885 "Request cancelled by runtime",
886 Some("cancelled".to_string()),
887 )
888 .await?;
889 let _ = err;
890 Ok(())
891 }
892 Err(err) => Err(err),
893 }
894 }
895 }
896}
897
898async fn stream_http_response(
900 response: CopilotHttpResponse,
901 exchange: &CopilotRequestExchange,
902 cancel: &CancellationToken,
903) -> Result<(), CopilotRequestError> {
904 exchange
905 .start_response(response.status, response.status_text, response.headers)
906 .await?;
907
908 let mut body = response.body;
909 loop {
910 tokio::select! {
911 _ = cancel.cancelled() => {
912 return exchange
913 .error_response("Request cancelled by runtime", Some("cancelled".to_string()))
914 .await;
915 }
916 next = body.next() => match next {
917 Some(Ok(chunk)) => {
918 for piece in chunk.chunks(32 * 1024) {
919 exchange.write_binary(piece).await?;
920 }
921 }
922 Some(Err(e)) => {
923 return exchange.error_response(e.to_string(), None).await;
924 }
925 None => break,
926 }
927 }
928 }
929 exchange.end_response().await
930}
931
932async fn pump_websocket_requests(
935 handler: &dyn CopilotWebSocketHandler,
936 exchange: &CopilotRequestExchange,
937 cancel: &CancellationToken,
938) -> Result<(), CopilotRequestError> {
939 loop {
940 tokio::select! {
941 _ = cancel.cancelled() => {
942 return Err(CopilotRequestError::message("Request cancelled by runtime"));
943 }
944 frame = exchange.recv_body() => match frame {
945 Some(data) => {
946 handler
947 .send_request_message(CopilotWebSocketMessage { data, binary: false })
948 .await?;
949 }
950 None => return Ok(()),
951 }
952 }
953 }
954}
955
956async fn finalize_exchange(
959 exchange: &CopilotRequestExchange,
960 result: Result<(), CopilotRequestError>,
961) {
962 match result {
963 Ok(()) => {
964 if !exchange.finished() {
965 fail_via_response(
966 exchange,
967 502,
968 "Copilot request handler returned without finalising the response".to_string(),
969 )
970 .await;
971 }
972 }
973 Err(err) => {
974 if exchange.finished() {
975 return;
976 }
977 if exchange.cancel.is_cancelled() {
978 if !exchange.started() {
979 let _ = exchange.start_response(499, None, HeaderMap::new()).await;
980 }
981 let _ = exchange
982 .error_response(
983 "Request cancelled by runtime",
984 Some("cancelled".to_string()),
985 )
986 .await;
987 } else {
988 fail_via_response(exchange, 502, err.to_string()).await;
989 }
990 }
991 }
992}
993
994async fn fail_via_response(exchange: &CopilotRequestExchange, status: u16, message: String) {
995 if !exchange.started() {
996 let _ = exchange
997 .start_response(status, None, HeaderMap::new())
998 .await;
999 }
1000 let _ = exchange.error_response(message, None).await;
1001}
1002
1003pub(crate) struct CopilotRequestDispatcher {
1006 handler: Arc<dyn CopilotRequestHandler>,
1007 client: OnceLock<Weak<ClientInner>>,
1008 pending: Mutex<HashMap<String, Arc<CopilotRequestExchange>>>,
1009}
1010
1011impl CopilotRequestDispatcher {
1012 pub(crate) fn new(handler: Arc<dyn CopilotRequestHandler>) -> Self {
1013 Self {
1014 handler,
1015 client: OnceLock::new(),
1016 pending: Mutex::new(HashMap::new()),
1017 }
1018 }
1019
1020 pub(crate) fn set_client(&self, client: Weak<ClientInner>) {
1021 let _ = self.client.set(client);
1022 }
1023
1024 fn client(&self) -> Option<Client> {
1025 self.client
1026 .get()
1027 .and_then(Weak::upgrade)
1028 .map(Client::from_inner)
1029 }
1030
1031 fn client_weak(&self) -> Weak<ClientInner> {
1032 self.client.get().cloned().unwrap_or_else(Weak::new)
1033 }
1034
1035 pub(crate) async fn dispatch(self: &Arc<Self>, request: JsonRpcRequest) {
1036 match request.method.as_str() {
1037 METHOD_HTTP_REQUEST_START => self.handle_start(request).await,
1038 METHOD_HTTP_REQUEST_CHUNK => self.handle_chunk(request).await,
1039 other => {
1040 warn!(method = other, "unknown llmInference request method");
1041 self.send_error(request.id, "unknown llmInference method")
1042 .await;
1043 }
1044 }
1045 }
1046
1047 fn get_or_create_exchange(&self, request_id: String) -> Arc<CopilotRequestExchange> {
1048 self.pending
1054 .lock()
1055 .entry(request_id.clone())
1056 .or_insert_with(|| {
1057 Arc::new(CopilotRequestExchange::new(request_id, self.client_weak()))
1058 })
1059 .clone()
1060 }
1061
1062 async fn handle_start(self: &Arc<Self>, request: JsonRpcRequest) {
1063 let id = request.id;
1064 let Some(params) = parse_params::<LlmInferenceHttpRequestStartRequest>(&request) else {
1065 self.send_error(id, "invalid llmInference.httpRequestStart params")
1066 .await;
1067 return;
1068 };
1069
1070 let request_id = params.request_id.clone().into_inner();
1073 let exchange = self.get_or_create_exchange(request_id.clone());
1074 exchange.set_context(params);
1075
1076 let handler = self.handler.clone();
1077 let dispatcher = Arc::clone(self);
1078 let exchange_for_task = exchange.clone();
1079 tokio::spawn(async move {
1080 let result = drive_exchange(&exchange_for_task, &handler).await;
1081 finalize_exchange(&exchange_for_task, result).await;
1082 dispatcher.remove_pending(&request_id);
1083 });
1084
1085 self.ack(id).await;
1086 }
1087
1088 async fn handle_chunk(&self, request: JsonRpcRequest) {
1089 let id = request.id;
1090 let Some(params) = parse_params::<LlmInferenceHttpRequestChunkRequest>(&request) else {
1091 self.send_error(id, "invalid llmInference.httpRequestChunk params")
1092 .await;
1093 return;
1094 };
1095
1096 let exchange = self.get_or_create_exchange(params.request_id.to_string());
1099 apply_chunk(&exchange, ¶ms);
1100
1101 self.ack(id).await;
1102 }
1103
1104 fn remove_pending(&self, request_id: &str) {
1105 self.pending.lock().remove(request_id);
1106 }
1107
1108 async fn ack(&self, id: u64) {
1109 let Some(client) = self.client() else {
1110 return;
1111 };
1112 let _ = client
1113 .send_response(&JsonRpcResponse {
1114 jsonrpc: "2.0".to_string(),
1115 id,
1116 result: Some(serde_json::json!({})),
1117 error: None,
1118 })
1119 .await;
1120 }
1121
1122 async fn send_error(&self, id: u64, message: &str) {
1123 let Some(client) = self.client() else {
1124 return;
1125 };
1126 let _ = client
1127 .send_response(&JsonRpcResponse {
1128 jsonrpc: "2.0".to_string(),
1129 id,
1130 result: None,
1131 error: Some(crate::JsonRpcError {
1132 code: error_codes::INTERNAL_ERROR,
1133 message: message.to_string(),
1134 data: None,
1135 }),
1136 })
1137 .await;
1138 }
1139}
1140
1141fn apply_chunk(exchange: &CopilotRequestExchange, params: &LlmInferenceHttpRequestChunkRequest) {
1144 if params.cancel == Some(true) {
1145 exchange.push_cancel();
1146 return;
1147 }
1148
1149 if !params.data.is_empty() {
1150 let decoded = if params.binary == Some(true) {
1151 match base64::engine::general_purpose::STANDARD.decode(params.data.as_bytes()) {
1152 Ok(bytes) => bytes,
1153 Err(e) => {
1154 warn!(error = %e, "failed to decode base64 llmInference body chunk");
1155 return;
1156 }
1157 }
1158 } else {
1159 params.data.clone().into_bytes()
1160 };
1161 exchange.push_chunk(decoded);
1162 }
1163
1164 if params.end == Some(true) {
1165 exchange.push_end();
1166 }
1167}
1168
1169fn parse_params<T: serde::de::DeserializeOwned>(request: &JsonRpcRequest) -> Option<T> {
1170 request
1171 .params
1172 .as_ref()
1173 .and_then(|p| serde_json::from_value(p.clone()).ok())
1174}
1175
1176fn headers_from_wire(wire: &HashMap<String, Vec<String>>) -> HeaderMap {
1179 let mut headers = HeaderMap::new();
1180 for (name, values) in wire {
1181 let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else {
1182 continue;
1183 };
1184 for value in values {
1185 let Ok(header_value) = HeaderValue::from_str(value) else {
1186 continue;
1187 };
1188 headers.append(header_name.clone(), header_value);
1189 }
1190 }
1191 headers
1192}
1193
1194fn headers_to_wire(headers: &HeaderMap) -> HashMap<String, Vec<String>> {
1197 let mut wire: HashMap<String, Vec<String>> = HashMap::new();
1198 for (name, value) in headers {
1199 let Ok(value) = value.to_str() else {
1200 continue;
1201 };
1202 wire.entry(name.as_str().to_string())
1203 .or_default()
1204 .push(value.to_string());
1205 }
1206 wire
1207}