1use super::error::OAuthError;
2use super::handler::{OAuthCallback, OAuthHandler};
3use futures::future::BoxFuture;
4use std::process::Command;
5use std::time::Duration;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::net::TcpListener;
8use tokio::time::timeout;
9
10pub struct BrowserOAuthHandler {
13 listener: TcpListener,
14 redirect_uri: String,
15}
16
17impl BrowserOAuthHandler {
18 pub fn new() -> Result<Self, std::io::Error> {
19 let std_listener = std::net::TcpListener::bind("127.0.0.1:0")?;
20 let port = std_listener.local_addr()?.port();
21 std_listener.set_nonblocking(true)?;
22 let listener = TcpListener::from_std(std_listener)?;
23 Ok(Self { listener, redirect_uri: format!("http://127.0.0.1:{port}/oauth2callback") })
24 }
25
26 pub fn with_redirect_uri(redirect_uri: impl Into<String>, port: u16) -> Result<Self, std::io::Error> {
31 let std_listener = std::net::TcpListener::bind(format!("127.0.0.1:{port}"))?;
32 std_listener.set_nonblocking(true)?;
33 let listener = TcpListener::from_std(std_listener)?;
34 Ok(Self { listener, redirect_uri: redirect_uri.into() })
35 }
36}
37
38impl OAuthHandler for BrowserOAuthHandler {
39 fn redirect_uri(&self) -> &str {
40 &self.redirect_uri
41 }
42
43 fn authorize(&self, auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
44 let auth_url = auth_url.to_string();
45 Box::pin(async move {
46 if let Err(e) = open_browser(&auth_url) {
47 tracing::warn!("Failed to open browser: {e}. Open manually: {auth_url}");
48 }
49
50 accept_oauth_callback(&self.listener).await
51 })
52 }
53}
54
55pub async fn accept_oauth_callback(listener: &TcpListener) -> Result<OAuthCallback, OAuthError> {
61 loop {
62 let (mut socket, _) = listener.accept().await?;
63 let request_line = {
64 let mut reader = BufReader::new(&mut socket);
65 let mut line = String::new();
66 match timeout(Duration::from_secs(2), reader.read_line(&mut line)).await {
67 Ok(Ok(1..)) => line,
68 _ => continue,
69 }
70 };
71
72 match parse_callback_from_request(&request_line) {
73 Ok(callback) => {
74 let _ = socket.write_all(create_success_response().as_bytes()).await;
75 return Ok(callback);
76 }
77 Err(e) => {
78 if request_line.contains('?') {
79 return Err(e);
80 }
81 }
82 }
83 }
84}
85
86pub async fn wait_for_callback(port: u16) -> Result<OAuthCallback, OAuthError> {
91 let addr = format!("127.0.0.1:{port}");
92 let listener = TcpListener::bind(&addr).await?;
93 accept_oauth_callback(&listener).await
94}
95
96pub fn open_browser(url: &str) -> Result<(), OAuthError> {
98 #[cfg(target_os = "macos")]
99 {
100 Command::new("open").arg(url).spawn().map_err(std::io::Error::other)?;
101 }
102
103 #[cfg(target_os = "linux")]
104 {
105 Command::new("xdg-open").arg(url).spawn().map_err(std::io::Error::other)?;
106 }
107
108 #[cfg(target_os = "windows")]
109 {
110 Command::new("cmd").args(["/C", "start", url]).spawn().map_err(std::io::Error::other)?;
111 }
112
113 Ok(())
114}
115
116fn parse_callback_from_request(request_line: &str) -> Result<OAuthCallback, OAuthError> {
118 let parts: Vec<&str> = request_line.split_whitespace().collect();
120 if parts.len() < 2 {
121 return Err(OAuthError::InvalidCallback("Invalid HTTP request format".to_string()));
122 }
123
124 let path = parts[1];
125 let query_start =
126 path.find('?').ok_or_else(|| OAuthError::InvalidCallback("No query parameters in callback".to_string()))?;
127
128 let query = &path[query_start + 1..];
129
130 for param in query.split('&') {
132 if let Some((key, value)) = param.split_once('=')
133 && key == "error"
134 {
135 let error_desc = query
136 .split('&')
137 .find_map(|p| {
138 p.split_once('=').filter(|(k, _)| *k == "error_description").map(|(_, v)| urlencoding_decode(v))
139 })
140 .unwrap_or_else(|| value.to_string());
141 return Err(OAuthError::InvalidCallback(format!("OAuth error: {error_desc}")));
142 }
143 }
144
145 let mut code = None;
147 let mut state = None;
148
149 for param in query.split('&') {
150 if let Some((key, value)) = param.split_once('=') {
151 match key {
152 "code" => code = Some(urlencoding_decode(value)),
153 "state" => state = Some(urlencoding_decode(value)),
154 _ => {}
155 }
156 }
157 }
158
159 let code = code.ok_or_else(|| OAuthError::InvalidCallback("No authorization code in callback".into()))?;
160 let state = state.ok_or_else(|| OAuthError::InvalidCallback("No state parameter in callback".into()))?;
161
162 Ok(OAuthCallback { code, state })
163}
164
/// Percent-decodes an `application/x-www-form-urlencoded` query component.
///
/// - `+` becomes a space.
/// - `%XX` (exactly two hex digits) becomes the byte `0xXX`.
/// - A `%` that is not followed by two hex digits is kept literally, and the characters
///   after it are decoded normally.
///
/// The resulting bytes are interpreted as UTF-8, so escaped multi-byte sequences such as
/// `%C3%A9` decode to `é`. Invalid UTF-8 is replaced with U+FFFD rather than failing.
fn urlencoding_decode(s: &str) -> String {
    /// Decodes exactly two ASCII hex digits into a byte; anything else yields `None`.
    fn hex_pair(pair: &[u8]) -> Option<u8> {
        let [hi, lo] = pair else { return None };
        let hi = char::from(*hi).to_digit(16)?;
        let lo = char::from(*lo).to_digit(16)?;
        u8::try_from(hi * 16 + lo).ok()
    }

    let input = s.as_bytes();
    let mut bytes = Vec::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        match input[i] {
            b'%' => match input.get(i + 1..i + 3).and_then(hex_pair) {
                Some(byte) => {
                    bytes.push(byte);
                    i += 3;
                    continue;
                }
                None => bytes.push(b'%'),
            },
            b'+' => bytes.push(b' '),
            other => bytes.push(other),
        }
        i += 1;
    }

    // Escapes may encode pieces of multi-byte UTF-8 characters, so decode the bytes as a whole.
    String::from_utf8_lossy(&bytes).into_owned()
}
188
189fn create_success_response() -> String {
191 let body = include_str!("oauth_success.html");
192
193 format!(
194 "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
195 body.len(),
196 body
197 )
198}
199
// Unit tests for request-line parsing and percent-decoding, plus async tests that drive
// `accept_oauth_callback` over real loopback TCP connections.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_callback_from_valid_request() {
        // `%2F` in the code must come back as `/`.
        let request = "GET /oauth2callback?code=4%2F0AYWS-abc123&state=verifier HTTP/1.1\r\n";
        let callback = parse_callback_from_request(request).unwrap();
        assert_eq!(callback.code, "4/0AYWS-abc123");
        assert_eq!(callback.state, "verifier");
    }

    #[test]
    fn parse_callback_handles_plus_encoding() {
        // Form encoding uses `+` for spaces.
        let request = "GET /oauth2callback?code=hello+world&state=test+state HTTP/1.1\r\n";
        let callback = parse_callback_from_request(request).unwrap();
        assert_eq!(callback.code, "hello world");
        assert_eq!(callback.state, "test state");
    }

    #[test]
    fn parse_callback_returns_error_for_oauth_error() {
        // A provider-reported error surfaces its decoded `error_description`.
        let request = "GET /oauth2callback?error=access_denied&error_description=User+denied HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("User denied"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_code() {
        let request = "GET /oauth2callback?state=verifier HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No authorization code"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_state() {
        let request = "GET /oauth2callback?code=abc123 HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No state parameter"));
    }

    // Runs under a Tokio runtime because the constructor hands the socket to Tokio.
    // Port 0 lets the OS pick the port, so this only checks that the redirect URI is
    // stored verbatim, not which port was bound.
    #[tokio::test]
    async fn with_redirect_uri_binds_to_specified_port() {
        let handler = BrowserOAuthHandler::with_redirect_uri("http://localhost:9999/callback", 0).unwrap();
        assert_eq!(handler.redirect_uri(), "http://localhost:9999/callback");
    }

    #[test]
    fn urlencoding_decode_handles_percent() {
        assert_eq!(urlencoding_decode("hello%20world"), "hello world");
        assert_eq!(urlencoding_decode("a%2Fb"), "a/b");
    }

    #[test]
    fn urlencoding_decode_handles_plus() {
        assert_eq!(urlencoding_decode("hello+world"), "hello world");
    }

    // A request without a query string (like a browser's favicon probe) that arrives
    // before the real redirect must be skipped, not treated as a failed callback.
    #[tokio::test]
    async fn accept_oauth_callback_skips_stale_favicon_request() {
        use tokio::io::AsyncWriteExt;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = tokio::spawn(async move { accept_oauth_callback(&listener).await });

        // Connected first, so it is expected to be accepted before `real`.
        let mut stale = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await.unwrap();
        stale.write_all(b"GET /favicon.ico HTTP/1.1\r\n").await.unwrap();

        let mut real = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await.unwrap();
        real.write_all(b"GET /oauth2callback?code=abc&state=xyz HTTP/1.1\r\n").await.unwrap();

        let callback = handle.await.unwrap().unwrap();
        assert_eq!(callback.code, "abc");
        assert_eq!(callback.state, "xyz");
    }

    // A client that connects and disconnects without sending anything must not end the
    // accept loop.
    #[tokio::test]
    async fn accept_oauth_callback_skips_closed_connection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = tokio::spawn(async move { accept_oauth_callback(&listener).await });

        {
            // Dropped at the end of this scope: connect, then close with no data sent.
            let _stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await.unwrap();
        }

        let mut real = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await.unwrap();
        real.write_all(b"GET /oauth2callback?code=def&state=uvw HTTP/1.1\r\n").await.unwrap();

        let callback = handle.await.unwrap().unwrap();
        assert_eq!(callback.code, "def");
        assert_eq!(callback.state, "uvw");
    }
}