rho-coding-agent 0.15.0

A lightweight agent harness inspired by Pi
use futures_util::StreamExt;
use reqwest::StatusCode;

use crate::{
    auth::github_copilot_token::{GitHubCopilotAuthManager, GitHubCopilotAuthMaterial},
    model::{
        openai::{
            convert::{convert_openai_response, to_openai_message, to_openai_tool},
            stream::{convert_streamed_response, handle_openai_stream_line, trim_sse_line_end},
            types::{ChatRequest, ChatResponse, ChatStreamOptions},
        },
        ModelError, ModelEvent, ModelProvider, ModelRequest, ModelResponse,
    },
};

const DEFAULT_COPILOT_CHAT_COMPLETIONS_URL: &str = "https://api.githubcopilot.com/chat/completions";
const USER_AGENT: &str = concat!("rho/", env!("CARGO_PKG_VERSION"));
const EDITOR_VERSION: &str = concat!("rho/", env!("CARGO_PKG_VERSION"));
const EDITOR_PLUGIN_VERSION: &str = concat!("rho/", env!("CARGO_PKG_VERSION"));
const COPILOT_INTEGRATION_ID: &str = "vscode-chat";

pub struct GitHubCopilotProvider {
    client: reqwest::Client,
    auth: GitHubCopilotAuthManager,
    model: String,
}

impl GitHubCopilotProvider {
    pub fn new(model: String, auth: GitHubCopilotAuthManager) -> Result<Self, ModelError> {
        auth.ensure_auth_available()?;
        Ok(Self {
            client: reqwest::Client::new(),
            auth,
            model,
        })
    }

    #[cfg(test)]
    fn new_with_client(
        model: String,
        auth: GitHubCopilotAuthManager,
        client: reqwest::Client,
    ) -> Self {
        Self {
            client,
            auth,
            model,
        }
    }

    fn chat_request(&self, request: ModelRequest, stream: bool) -> Result<ChatRequest, ModelError> {
        let messages = request
            .messages
            .into_iter()
            .map(to_openai_message)
            .collect::<Result<Vec<_>, _>>()?;
        let tools = request
            .tools
            .into_iter()
            .map(to_openai_tool)
            .collect::<Vec<_>>();
        let has_tools = !tools.is_empty();
        Ok(ChatRequest {
            model: self.model.clone(),
            messages,
            tools: has_tools.then_some(tools),
            tool_choice: has_tools.then_some("auto"),
            stream,
            stream_options: stream.then_some(ChatStreamOptions {
                include_usage: true,
            }),
        })
    }

    fn apply_headers(
        &self,
        builder: reqwest::RequestBuilder,
        auth: &GitHubCopilotAuthMaterial,
    ) -> reqwest::RequestBuilder {
        builder
            .bearer_auth(&auth.token)
            .header("Accept", "application/json")
            .header("User-Agent", USER_AGENT)
            .header("Editor-Version", EDITOR_VERSION)
            .header("Editor-Plugin-Version", EDITOR_PLUGIN_VERSION)
            .header("Copilot-Integration-Id", COPILOT_INTEGRATION_ID)
    }

    async fn send_chat_once(
        &self,
        body: &ChatRequest,
        auth: &GitHubCopilotAuthMaterial,
    ) -> Result<reqwest::Response, ModelError> {
        let endpoint = if auth.chat_endpoint.trim().is_empty() {
            DEFAULT_COPILOT_CHAT_COMPLETIONS_URL
        } else {
            auth.chat_endpoint.as_str()
        };
        Ok(self
            .apply_headers(self.client.post(endpoint), auth)
            .json(body)
            .send()
            .await?)
    }

    async fn send_chat_with_retry(
        &self,
        body: ChatRequest,
        auth: GitHubCopilotAuthMaterial,
    ) -> Result<reqwest::Response, ModelError> {
        let response = self.send_chat_once(&body, &auth).await?;
        if response.status() != StatusCode::UNAUTHORIZED {
            return Ok(response);
        }
        if let Some(refreshed) = self.auth.force_refresh(&self.client).await? {
            return self.send_chat_once(&body, &refreshed).await;
        }
        Ok(response)
    }
}

#[async_trait::async_trait(?Send)]
impl ModelProvider for GitHubCopilotProvider {
    async fn send_turn(&self, request: ModelRequest) -> Result<ModelResponse, ModelError> {
        let body = self.chat_request(request, false)?;
        let auth = self.auth.auth_material(&self.client).await?;
        let response = self.send_chat_with_retry(body, auth).await?;
        if !response.status().is_success() {
            return Err(http_status_error(response).await);
        }
        let response: ChatResponse = response.json().await?;
        convert_openai_response(response)
    }

    async fn send_turn_stream(
        &self,
        request: ModelRequest,
        on_event: &mut dyn FnMut(ModelEvent) -> Result<(), ModelError>,
    ) -> Result<ModelResponse, ModelError> {
        let body = self.chat_request(request, true)?;
        let auth = self.auth.auth_material(&self.client).await?;
        let response = self.send_chat_with_retry(body, auth).await?;
        if !response.status().is_success() {
            return Err(http_status_error(response).await);
        }

        let mut text = String::new();
        let mut tool_calls = Vec::new();
        let mut buffer = Vec::new();
        let mut stream = response.bytes_stream();
        while let Some(chunk) = stream.next().await {
            buffer.extend_from_slice(&chunk?);
            while let Some(newline) = buffer.iter().position(|byte| *byte == b'\n') {
                let mut line = buffer.drain(..=newline).collect::<Vec<_>>();
                trim_sse_line_end(&mut line);
                let line = std::str::from_utf8(&line).map_err(|err| {
                    ModelError::InvalidResponse(format!(
                        "streamed response contained invalid utf-8: {err}"
                    ))
                })?;
                handle_openai_stream_line(line, &mut text, &mut tool_calls, on_event)?;
            }
        }
        if !buffer.is_empty() {
            trim_sse_line_end(&mut buffer);
            let line = std::str::from_utf8(&buffer).map_err(|err| {
                ModelError::InvalidResponse(format!(
                    "streamed response contained invalid utf-8: {err}"
                ))
            })?;
            handle_openai_stream_line(line, &mut text, &mut tool_calls, on_event)?;
        }

        convert_streamed_response(text, tool_calls)
    }
}

async fn http_status_error(response: reqwest::Response) -> ModelError {
    let status = response.status();
    let body = response.text().await.unwrap_or_default();
    if status == StatusCode::UNAUTHORIZED {
        ModelError::MissingGithubCopilotAuth
    } else {
        ModelError::HttpStatus { status, body }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    use crate::{
        credentials::{save_github_copilot_tokens, GitHubCopilotTokens, MemoryCredentialStore},
        model::{ContentBlock, Message},
    };
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
    };

    #[test]
    fn chat_request_preserves_model_and_streaming_flag() {
        let store = Arc::new(crate::credentials::MemoryCredentialStore::default());
        save_github_copilot_tokens(
            store.as_ref(),
            &GitHubCopilotTokens {
                github_access_token: "github".into(),
                github_refresh_token: None,
                github_expires_at_unix: None,
                copilot_token: Some("copilot".into()),
                copilot_expires_at_unix: Some(4_102_444_800),
                copilot_refresh_after_unix: None,
                copilot_token_endpoint: None,
                copilot_chat_endpoint: None,
                copilot_models_endpoint: None,
            },
        )
        .unwrap();
        let provider =
            GitHubCopilotProvider::new("gpt-4.1".into(), GitHubCopilotAuthManager::new(store))
                .unwrap();

        let body = provider
            .chat_request(
                ModelRequest {
                    messages: vec![Message::user_text("hello")],
                    tools: Vec::new(),
                    prompt_cache_key: None,
                },
                true,
            )
            .unwrap();

        assert_eq!(body.model, "gpt-4.1");
        assert!(body.stream);
        assert!(body.stream_options.is_some());
    }

    #[test]
    fn provider_construction_requires_available_auth() {
        let result = GitHubCopilotProvider::new(
            "gpt-4.1".into(),
            GitHubCopilotAuthManager::new_with_env_token(
                Arc::new(MemoryCredentialStore::default()),
                None,
            ),
        );

        assert!(matches!(result, Err(ModelError::MissingGithubCopilotAuth)));
    }

    #[tokio::test]
    async fn chat_retries_once_after_unauthorized_with_refreshed_endpoint() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let base_url = format!("http://{}", listener.local_addr().unwrap());
        let requests = Arc::new(Mutex::new(Vec::new()));
        let requests_for_server = Arc::clone(&requests);
        let base_url_for_server = base_url.clone();
        tokio::spawn(async move {
            for index in 0..4 {
                let (mut stream, _) = listener.accept().await.unwrap();
                let mut buffer = [0; 4096];
                let bytes = stream.read(&mut buffer).await.unwrap();
                let request = String::from_utf8_lossy(&buffer[..bytes]).to_string();
                requests_for_server.lock().unwrap().push(request);
                let body = match index {
                    0 => format!(
                        "{{\"token\":\"first\",\"endpoints\":{{\"chat\":\"{base_url_for_server}/chat\"}}}}"
                    ),
                    1 => String::new(),
                    2 => format!(
                        "{{\"token\":\"second\",\"endpoints\":{{\"chat\":\"{base_url_for_server}/chat\"}}}}"
                    ),
                    3 => r#"{"choices":[{"message":{"content":"ok"}}]}"#.to_string(),
                    _ => unreachable!(),
                };
                let status = if index == 1 {
                    "401 Unauthorized"
                } else {
                    "200 OK"
                };
                let reply = format!(
                    "HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
                    body.len(), body
                );
                stream.write_all(reply.as_bytes()).await.unwrap();
            }
        });
        let store = Arc::new(MemoryCredentialStore::default());
        save_github_copilot_tokens(
            store.as_ref(),
            &GitHubCopilotTokens {
                github_access_token: "github".into(),
                github_refresh_token: None,
                github_expires_at_unix: None,
                copilot_token: None,
                copilot_expires_at_unix: None,
                copilot_refresh_after_unix: None,
                copilot_token_endpoint: Some(base_url.clone()),
                copilot_chat_endpoint: None,
                copilot_models_endpoint: None,
            },
        )
        .unwrap();
        let provider = GitHubCopilotProvider::new_with_client(
            "gpt-4.1".into(),
            GitHubCopilotAuthManager::new(store),
            reqwest::Client::new(),
        );

        let response = provider
            .send_turn(ModelRequest {
                messages: vec![Message::user_text("hello")],
                tools: Vec::new(),
                prompt_cache_key: None,
            })
            .await
            .unwrap();

        assert!(matches!(
            response,
            ModelResponse::Assistant(blocks) if matches!(blocks.as_slice(), [ContentBlock::Text(text)] if text == "ok")
        ));
        let requests = requests.lock().unwrap();
        assert_eq!(
            requests
                .iter()
                .filter(|request| request.contains("POST /chat"))
                .count(),
            2
        );
        assert!(requests
            .iter()
            .any(|request| request.contains("authorization: Bearer first")));
        assert!(requests
            .iter()
            .any(|request| request.contains("authorization: Bearer second")));
    }
}