Skip to main content

llm/oauth/
browser.rs

1use super::error::OAuthError;
2use super::handler::{OAuthCallback, OAuthHandler};
3use futures::future::BoxFuture;
4use std::process::Command;
5use std::time::Duration;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::net::TcpListener;
8use tokio::time::timeout;
9
10/// Default `OAuthHandler` that opens the system browser and listens
11/// for the OAuth callback on a dynamically-assigned local port.
12pub struct BrowserOAuthHandler {
13    listener: TcpListener,
14    redirect_uri: String,
15}
16
17impl BrowserOAuthHandler {
18    pub fn new() -> Result<Self, std::io::Error> {
19        let std_listener = std::net::TcpListener::bind("127.0.0.1:0")?;
20        let port = std_listener.local_addr()?.port();
21        std_listener.set_nonblocking(true)?;
22        let listener = TcpListener::from_std(std_listener)?;
23        Ok(Self { listener, redirect_uri: format!("http://127.0.0.1:{port}/oauth2callback") })
24    }
25
26    /// Create a handler bound to a specific port with a custom redirect URI.
27    ///
28    /// Use this when the OAuth provider has a fixed redirect URI registered
29    /// (e.g. `http://localhost:1455/auth/callback` for Codex).
30    pub fn with_redirect_uri(redirect_uri: impl Into<String>, port: u16) -> Result<Self, std::io::Error> {
31        let std_listener = std::net::TcpListener::bind(format!("127.0.0.1:{port}"))?;
32        std_listener.set_nonblocking(true)?;
33        let listener = TcpListener::from_std(std_listener)?;
34        Ok(Self { listener, redirect_uri: redirect_uri.into() })
35    }
36}
37
38impl OAuthHandler for BrowserOAuthHandler {
39    fn redirect_uri(&self) -> &str {
40        &self.redirect_uri
41    }
42
43    fn authorize(&self, auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
44        let auth_url = auth_url.to_string();
45        Box::pin(async move {
46            if let Err(e) = open_browser(&auth_url) {
47                tracing::warn!("Failed to open browser: {e}. Open manually: {auth_url}");
48            }
49
50            accept_oauth_callback(&self.listener).await
51        })
52    }
53}
54
55/// Accept an OAuth callback on an already-bound listener.
56///
57/// Loops until it receives a valid OAuth callback, skipping stale connections
58/// left over from previous flows (e.g. browser favicon requests, preconnect
59/// sockets, or closed connections).
60pub async fn accept_oauth_callback(listener: &TcpListener) -> Result<OAuthCallback, OAuthError> {
61    loop {
62        let (mut socket, _) = listener.accept().await?;
63        let request_line = {
64            let mut reader = BufReader::new(&mut socket);
65            let mut line = String::new();
66            match timeout(Duration::from_secs(2), reader.read_line(&mut line)).await {
67                Ok(Ok(1..)) => line,
68                _ => continue,
69            }
70        };
71
72        match parse_callback_from_request(&request_line) {
73            Ok(callback) => {
74                let _ = socket.write_all(create_success_response().as_bytes()).await;
75                return Ok(callback);
76            }
77            Err(e) => {
78                if request_line.contains('?') {
79                    return Err(e);
80                }
81            }
82        }
83    }
84}
85
86/// Start a local callback server to capture the OAuth authorization code and state
87///
88/// Listens on the specified port and waits for the OAuth redirect.
89/// Returns the authorization code and state (CSRF token) from the callback URL.
90pub async fn wait_for_callback(port: u16) -> Result<OAuthCallback, OAuthError> {
91    let addr = format!("127.0.0.1:{port}");
92    let listener = TcpListener::bind(&addr).await?;
93    accept_oauth_callback(&listener).await
94}
95
96/// Open a URL in the default browser
97pub fn open_browser(url: &str) -> Result<(), OAuthError> {
98    #[cfg(target_os = "macos")]
99    {
100        Command::new("open").arg(url).spawn().map_err(std::io::Error::other)?;
101    }
102
103    #[cfg(target_os = "linux")]
104    {
105        Command::new("xdg-open").arg(url).spawn().map_err(std::io::Error::other)?;
106    }
107
108    #[cfg(target_os = "windows")]
109    {
110        Command::new("cmd").args(["/C", "start", url]).spawn().map_err(std::io::Error::other)?;
111    }
112
113    Ok(())
114}
115
116/// Parse the authorization code and state from the HTTP request line
117fn parse_callback_from_request(request_line: &str) -> Result<OAuthCallback, OAuthError> {
118    // Request format: GET /oauth2callback?code=XXX&state=YYY HTTP/1.1
119    let parts: Vec<&str> = request_line.split_whitespace().collect();
120    if parts.len() < 2 {
121        return Err(OAuthError::InvalidCallback("Invalid HTTP request format".to_string()));
122    }
123
124    let path = parts[1];
125    let query_start =
126        path.find('?').ok_or_else(|| OAuthError::InvalidCallback("No query parameters in callback".to_string()))?;
127
128    let query = &path[query_start + 1..];
129
130    // Check for error in callback
131    for param in query.split('&') {
132        if let Some((key, value)) = param.split_once('=')
133            && key == "error"
134        {
135            let error_desc = query
136                .split('&')
137                .find_map(|p| {
138                    p.split_once('=').filter(|(k, _)| *k == "error_description").map(|(_, v)| urlencoding_decode(v))
139                })
140                .unwrap_or_else(|| value.to_string());
141            return Err(OAuthError::InvalidCallback(format!("OAuth error: {error_desc}")));
142        }
143    }
144
145    // Extract code and state
146    let mut code = None;
147    let mut state = None;
148
149    for param in query.split('&') {
150        if let Some((key, value)) = param.split_once('=') {
151            match key {
152                "code" => code = Some(urlencoding_decode(value)),
153                "state" => state = Some(urlencoding_decode(value)),
154                _ => {}
155            }
156        }
157    }
158
159    let code = code.ok_or_else(|| OAuthError::InvalidCallback("No authorization code in callback".into()))?;
160    let state = state.ok_or_else(|| OAuthError::InvalidCallback("No state parameter in callback".into()))?;
161
162    Ok(OAuthCallback { code, state })
163}
164
/// Decode a URL query value (`application/x-www-form-urlencoded` style).
///
/// `+` becomes a space, and each `%XX` escape (exactly two hex digits) becomes
/// the byte `0xXX`. The resulting bytes are read as UTF-8, so multi-byte
/// escapes such as `%C3%A9` decode to `é`. Invalid UTF-8 is replaced with
/// U+FFFD. A `%` not followed by two hex digits is kept literally, and the
/// characters after it are decoded normally.
fn urlencoding_decode(s: &str) -> String {
    /// Value of one ASCII hex digit, or `None` for any other byte (including signs).
    fn hex_digit(b: u8) -> Option<u8> {
        char::from(b).to_digit(16).map(|d| d as u8)
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while let Some(&b) = bytes.get(i) {
        match b {
            b'%' => {
                let escaped = bytes
                    .get(i + 1..i + 3)
                    .and_then(|pair| Some((hex_digit(pair[0])? << 4) | hex_digit(pair[1])?));
                if let Some(byte) = escaped {
                    decoded.push(byte);
                    i += 3;
                    continue;
                }
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
188
189/// Create an HTML success response
190fn create_success_response() -> String {
191    let body = include_str!("oauth_success.html");
192
193    format!(
194        "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
195        body.len(),
196        body
197    )
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_callback_from_valid_request() {
        let request = "GET /oauth2callback?code=4%2F0AYWS-abc123&state=verifier HTTP/1.1\r\n";
        let callback = parse_callback_from_request(request).unwrap();
        assert_eq!(callback.code, "4/0AYWS-abc123");
        assert_eq!(callback.state, "verifier");
    }

    #[test]
    fn parse_callback_handles_plus_encoding() {
        let request = "GET /oauth2callback?code=hello+world&state=test+state HTTP/1.1\r\n";
        let callback = parse_callback_from_request(request).unwrap();
        assert_eq!(callback.code, "hello world");
        assert_eq!(callback.state, "test state");
    }

    #[test]
    fn parse_callback_returns_error_for_oauth_error() {
        let request = "GET /oauth2callback?error=access_denied&error_description=User+denied HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("User denied"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_code() {
        let request = "GET /oauth2callback?state=verifier HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No authorization code"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_state() {
        let request = "GET /oauth2callback?code=abc123 HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No state parameter"));
    }

    #[tokio::test]
    async fn new_derives_redirect_uri_from_bound_port() {
        let handler = BrowserOAuthHandler::new().unwrap();
        let addr = handler.listener.local_addr().unwrap();
        assert!(addr.ip().is_loopback());
        assert_ne!(addr.port(), 0);
        assert_eq!(handler.redirect_uri(), format!("http://127.0.0.1:{}/oauth2callback", addr.port()));
    }

    #[tokio::test]
    async fn with_redirect_uri_binds_to_specified_port() {
        // Grab a free port, release it immediately, then ask the handler to bind it.
        let port = std::net::TcpListener::bind("127.0.0.1:0").unwrap().local_addr().unwrap().port();
        let handler = BrowserOAuthHandler::with_redirect_uri("http://localhost:9999/callback", port).unwrap();
        assert_eq!(handler.redirect_uri(), "http://localhost:9999/callback");
        let addr = handler.listener.local_addr().unwrap();
        assert!(addr.ip().is_loopback());
        assert_eq!(addr.port(), port);
    }

    #[test]
    fn urlencoding_decode_handles_percent() {
        assert_eq!(urlencoding_decode("hello%20world"), "hello world");
        assert_eq!(urlencoding_decode("a%2Fb"), "a/b");
    }

    #[test]
    fn urlencoding_decode_handles_plus() {
        assert_eq!(urlencoding_decode("hello+world"), "hello world");
    }

    #[tokio::test]
    async fn accept_oauth_callback_skips_stale_favicon_request() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = tokio::spawn(async move { accept_oauth_callback(&listener).await });

        let mut stale = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await.unwrap();
        stale.write_all(b"GET /favicon.ico HTTP/1.1\r\n").await.unwrap();

        let mut real = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await.unwrap();
        real.write_all(b"GET /oauth2callback?code=abc&state=xyz HTTP/1.1\r\n").await.unwrap();

        let callback = handle.await.unwrap().unwrap();
        assert_eq!(callback.code, "abc");
        assert_eq!(callback.state, "xyz");
    }

    #[tokio::test]
    async fn accept_oauth_callback_skips_closed_connection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = tokio::spawn(async move { accept_oauth_callback(&listener).await });

        {
            let _stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await.unwrap();
        }

        let mut real = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await.unwrap();
        real.write_all(b"GET /oauth2callback?code=def&state=uvw HTTP/1.1\r\n").await.unwrap();

        let callback = handle.await.unwrap().unwrap();
        assert_eq!(callback.code, "def");
        assert_eq!(callback.state, "uvw");
    }
}