Skip to main content

llm/oauth/
browser.rs

1use super::error::OAuthError;
2use super::handler::{OAuthCallback, OAuthHandler};
3use futures::future::BoxFuture;
4use std::process::Command;
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7
8/// Default `OAuthHandler` that opens the system browser and listens
9/// for the OAuth callback on a dynamically-assigned local port.
10pub struct BrowserOAuthHandler {
11    listener: TcpListener,
12    redirect_uri: String,
13}
14
15impl BrowserOAuthHandler {
16    pub fn new() -> Result<Self, std::io::Error> {
17        let std_listener = std::net::TcpListener::bind("127.0.0.1:0")?;
18        let port = std_listener.local_addr()?.port();
19        std_listener.set_nonblocking(true)?;
20        let listener = TcpListener::from_std(std_listener)?;
21        Ok(Self { listener, redirect_uri: format!("http://127.0.0.1:{port}/oauth2callback") })
22    }
23
24    /// Create a handler bound to a specific port with a custom redirect URI.
25    ///
26    /// Use this when the OAuth provider has a fixed redirect URI registered
27    /// (e.g. `http://localhost:1455/auth/callback` for Codex).
28    pub fn with_redirect_uri(redirect_uri: impl Into<String>, port: u16) -> Result<Self, std::io::Error> {
29        let std_listener = std::net::TcpListener::bind(format!("127.0.0.1:{port}"))?;
30        std_listener.set_nonblocking(true)?;
31        let listener = TcpListener::from_std(std_listener)?;
32        Ok(Self { listener, redirect_uri: redirect_uri.into() })
33    }
34}
35
36impl OAuthHandler for BrowserOAuthHandler {
37    fn redirect_uri(&self) -> &str {
38        &self.redirect_uri
39    }
40
41    fn authorize(&self, auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
42        let auth_url = auth_url.to_string();
43        Box::pin(async move {
44            if let Err(e) = open_browser(&auth_url) {
45                tracing::warn!("Failed to open browser: {e}. Open manually: {auth_url}");
46            }
47
48            accept_oauth_callback(&self.listener).await
49        })
50    }
51}
52
53/// Accept a single OAuth callback on an already-bound listener.
54///
55/// Waits for one HTTP request, parses the authorization code and state,
56/// sends a success response, and returns the callback data.
57pub async fn accept_oauth_callback(listener: &TcpListener) -> Result<OAuthCallback, OAuthError> {
58    let (mut socket, _) = listener.accept().await?;
59
60    let mut reader = BufReader::new(&mut socket);
61    let mut request_line = String::new();
62    reader.read_line(&mut request_line).await?;
63
64    let callback = parse_callback_from_request(&request_line)?;
65
66    socket.write_all(create_success_response().as_bytes()).await?;
67
68    Ok(callback)
69}
70
71/// Start a local callback server to capture the OAuth authorization code and state
72///
73/// Listens on the specified port and waits for the OAuth redirect.
74/// Returns the authorization code and state (CSRF token) from the callback URL.
75pub async fn wait_for_callback(port: u16) -> Result<OAuthCallback, OAuthError> {
76    let addr = format!("127.0.0.1:{port}");
77    let listener = TcpListener::bind(&addr).await?;
78    accept_oauth_callback(&listener).await
79}
80
81/// Open a URL in the default browser
82pub fn open_browser(url: &str) -> Result<(), OAuthError> {
83    #[cfg(target_os = "macos")]
84    {
85        Command::new("open").arg(url).spawn().map_err(std::io::Error::other)?;
86    }
87
88    #[cfg(target_os = "linux")]
89    {
90        Command::new("xdg-open").arg(url).spawn().map_err(std::io::Error::other)?;
91    }
92
93    #[cfg(target_os = "windows")]
94    {
95        Command::new("cmd").args(["/C", "start", url]).spawn().map_err(std::io::Error::other)?;
96    }
97
98    Ok(())
99}
100
101/// Parse the authorization code and state from the HTTP request line
102fn parse_callback_from_request(request_line: &str) -> Result<OAuthCallback, OAuthError> {
103    // Request format: GET /oauth2callback?code=XXX&state=YYY HTTP/1.1
104    let parts: Vec<&str> = request_line.split_whitespace().collect();
105    if parts.len() < 2 {
106        return Err(OAuthError::InvalidCallback("Invalid HTTP request format".to_string()));
107    }
108
109    let path = parts[1];
110    let query_start =
111        path.find('?').ok_or_else(|| OAuthError::InvalidCallback("No query parameters in callback".to_string()))?;
112
113    let query = &path[query_start + 1..];
114
115    // Check for error in callback
116    for param in query.split('&') {
117        if let Some((key, value)) = param.split_once('=')
118            && key == "error"
119        {
120            let error_desc = query
121                .split('&')
122                .find_map(|p| {
123                    p.split_once('=').filter(|(k, _)| *k == "error_description").map(|(_, v)| urlencoding_decode(v))
124                })
125                .unwrap_or_else(|| value.to_string());
126            return Err(OAuthError::InvalidCallback(format!("OAuth error: {error_desc}")));
127        }
128    }
129
130    // Extract code and state
131    let mut code = None;
132    let mut state = None;
133
134    for param in query.split('&') {
135        if let Some((key, value)) = param.split_once('=') {
136            match key {
137                "code" => code = Some(urlencoding_decode(value)),
138                "state" => state = Some(urlencoding_decode(value)),
139                _ => {}
140            }
141        }
142    }
143
144    let code = code.ok_or_else(|| OAuthError::InvalidCallback("No authorization code in callback".into()))?;
145    let state = state.ok_or_else(|| OAuthError::InvalidCallback("No state parameter in callback".into()))?;
146
147    Ok(OAuthCallback { code, state })
148}
149
/// Percent-decode a URL query component.
///
/// Decoding rules:
/// - `%XX` (exactly two hex digits) becomes the raw byte `0xXX`;
/// - `+` becomes a space, as in `application/x-www-form-urlencoded`;
/// - a `%` not followed by two hex digits is kept literally, and the
///   characters after it are decoded normally.
///
/// The decoded bytes are interpreted as UTF-8, so multi-byte escapes such as
/// `%C3%A9` yield `é`. Invalid UTF-8 is replaced with U+FFFD.
fn urlencoding_decode(s: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_digit(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_digit);
                let lo = bytes.get(i + 2).copied().and_then(hex_digit);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
                // Not a valid escape: keep the '%' as-is.
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
173
174/// Create an HTML success response
175fn create_success_response() -> String {
176    let body = include_str!("oauth_success.html");
177
178    format!(
179        "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
180        body.len(),
181        body
182    )
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_callback_from_valid_request() {
        let request = "GET /oauth2callback?code=4%2F0AYWS-abc123&state=verifier HTTP/1.1\r\n";
        let callback = parse_callback_from_request(request).unwrap();
        assert_eq!(callback.code, "4/0AYWS-abc123");
        assert_eq!(callback.state, "verifier");
    }

    #[test]
    fn parse_callback_handles_plus_encoding() {
        let request = "GET /oauth2callback?code=hello+world&state=test+state HTTP/1.1\r\n";
        let callback = parse_callback_from_request(request).unwrap();
        assert_eq!(callback.code, "hello world");
        assert_eq!(callback.state, "test state");
    }

    #[test]
    fn parse_callback_returns_error_for_oauth_error() {
        let request = "GET /oauth2callback?error=access_denied&error_description=User+denied HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("User denied"));
    }

    #[test]
    fn parse_callback_falls_back_to_error_code_without_description() {
        let request = "GET /oauth2callback?error=access_denied HTTP/1.1\r\n";
        let err = parse_callback_from_request(request).unwrap_err().to_string();
        assert!(err.contains("access_denied"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_code() {
        let request = "GET /oauth2callback?state=verifier HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No authorization code"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_state() {
        let request = "GET /oauth2callback?code=abc123 HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No state parameter"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_query() {
        let err = parse_callback_from_request("GET /oauth2callback HTTP/1.1\r\n").unwrap_err().to_string();
        assert!(err.contains("No query parameters"));
    }

    #[test]
    fn parse_callback_returns_error_for_malformed_request_line() {
        let err = parse_callback_from_request("\r\n").unwrap_err().to_string();
        assert!(err.contains("Invalid HTTP request format"));
    }

    #[tokio::test]
    async fn new_redirect_uri_matches_bound_port() {
        let handler = BrowserOAuthHandler::new().unwrap();
        let addr = handler.listener.local_addr().unwrap();
        assert!(addr.ip().is_loopback());
        assert_eq!(handler.redirect_uri(), format!("http://127.0.0.1:{}/oauth2callback", addr.port()));
    }

    #[tokio::test]
    async fn with_redirect_uri_binds_to_specified_port() {
        // Find a free port and release it (the temporary listener is dropped at
        // the end of the statement), then ask the handler to bind exactly it.
        let port = std::net::TcpListener::bind("127.0.0.1:0").unwrap().local_addr().unwrap().port();
        let handler = BrowserOAuthHandler::with_redirect_uri("http://localhost:9999/callback", port).unwrap();
        assert_eq!(handler.redirect_uri(), "http://localhost:9999/callback");
        let addr = handler.listener.local_addr().unwrap();
        assert!(addr.ip().is_loopback());
        assert_eq!(addr.port(), port);
    }

    #[test]
    fn urlencoding_decode_handles_percent() {
        assert_eq!(urlencoding_decode("hello%20world"), "hello world");
        assert_eq!(urlencoding_decode("a%2Fb"), "a/b");
    }

    #[test]
    fn urlencoding_decode_handles_plus() {
        assert_eq!(urlencoding_decode("hello+world"), "hello world");
    }

    #[test]
    fn urlencoding_decode_keeps_invalid_escapes() {
        assert_eq!(urlencoding_decode("100%"), "100%");
        assert_eq!(urlencoding_decode("%zz"), "%zz");
    }
}