1use super::error::OAuthError;
2use super::handler::{OAuthCallback, OAuthHandler};
3use futures::future::BoxFuture;
4use std::process::Command;
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7
8pub struct BrowserOAuthHandler {
11 listener: TcpListener,
12 redirect_uri: String,
13}
14
15impl BrowserOAuthHandler {
16 pub fn new() -> Result<Self, std::io::Error> {
17 let std_listener = std::net::TcpListener::bind("127.0.0.1:0")?;
18 let port = std_listener.local_addr()?.port();
19 std_listener.set_nonblocking(true)?;
20 let listener = TcpListener::from_std(std_listener)?;
21 Ok(Self { listener, redirect_uri: format!("http://127.0.0.1:{port}/oauth2callback") })
22 }
23
24 pub fn with_redirect_uri(redirect_uri: impl Into<String>, port: u16) -> Result<Self, std::io::Error> {
29 let std_listener = std::net::TcpListener::bind(format!("127.0.0.1:{port}"))?;
30 std_listener.set_nonblocking(true)?;
31 let listener = TcpListener::from_std(std_listener)?;
32 Ok(Self { listener, redirect_uri: redirect_uri.into() })
33 }
34}
35
36impl OAuthHandler for BrowserOAuthHandler {
37 fn redirect_uri(&self) -> &str {
38 &self.redirect_uri
39 }
40
41 fn authorize(&self, auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
42 let auth_url = auth_url.to_string();
43 Box::pin(async move {
44 if let Err(e) = open_browser(&auth_url) {
45 tracing::warn!("Failed to open browser: {e}. Open manually: {auth_url}");
46 }
47
48 accept_oauth_callback(&self.listener).await
49 })
50 }
51}
52
53pub async fn accept_oauth_callback(listener: &TcpListener) -> Result<OAuthCallback, OAuthError> {
58 let (mut socket, _) = listener.accept().await?;
59
60 let mut reader = BufReader::new(&mut socket);
61 let mut request_line = String::new();
62 reader.read_line(&mut request_line).await?;
63
64 let callback = parse_callback_from_request(&request_line)?;
65
66 socket.write_all(create_success_response().as_bytes()).await?;
67
68 Ok(callback)
69}
70
71pub async fn wait_for_callback(port: u16) -> Result<OAuthCallback, OAuthError> {
76 let addr = format!("127.0.0.1:{port}");
77 let listener = TcpListener::bind(&addr).await?;
78 accept_oauth_callback(&listener).await
79}
80
81pub fn open_browser(url: &str) -> Result<(), OAuthError> {
83 #[cfg(target_os = "macos")]
84 {
85 Command::new("open").arg(url).spawn().map_err(std::io::Error::other)?;
86 }
87
88 #[cfg(target_os = "linux")]
89 {
90 Command::new("xdg-open").arg(url).spawn().map_err(std::io::Error::other)?;
91 }
92
93 #[cfg(target_os = "windows")]
94 {
95 Command::new("cmd").args(["/C", "start", url]).spawn().map_err(std::io::Error::other)?;
96 }
97
98 Ok(())
99}
100
101fn parse_callback_from_request(request_line: &str) -> Result<OAuthCallback, OAuthError> {
103 let parts: Vec<&str> = request_line.split_whitespace().collect();
105 if parts.len() < 2 {
106 return Err(OAuthError::InvalidCallback("Invalid HTTP request format".to_string()));
107 }
108
109 let path = parts[1];
110 let query_start =
111 path.find('?').ok_or_else(|| OAuthError::InvalidCallback("No query parameters in callback".to_string()))?;
112
113 let query = &path[query_start + 1..];
114
115 for param in query.split('&') {
117 if let Some((key, value)) = param.split_once('=')
118 && key == "error"
119 {
120 let error_desc = query
121 .split('&')
122 .find_map(|p| {
123 p.split_once('=').filter(|(k, _)| *k == "error_description").map(|(_, v)| urlencoding_decode(v))
124 })
125 .unwrap_or_else(|| value.to_string());
126 return Err(OAuthError::InvalidCallback(format!("OAuth error: {error_desc}")));
127 }
128 }
129
130 let mut code = None;
132 let mut state = None;
133
134 for param in query.split('&') {
135 if let Some((key, value)) = param.split_once('=') {
136 match key {
137 "code" => code = Some(urlencoding_decode(value)),
138 "state" => state = Some(urlencoding_decode(value)),
139 _ => {}
140 }
141 }
142 }
143
144 let code = code.ok_or_else(|| OAuthError::InvalidCallback("No authorization code in callback".into()))?;
145 let state = state.ok_or_else(|| OAuthError::InvalidCallback("No state parameter in callback".into()))?;
146
147 Ok(OAuthCallback { code, state })
148}
149
/// Decodes an `application/x-www-form-urlencoded` value: `+` becomes a space and each `%XX`
/// escape becomes the byte `0xXX`.
///
/// Escapes are decoded to raw bytes first and the result is then interpreted as UTF-8. As a
/// result, multi-byte sequences such as `%C3%A9` decode to `é`, and invalid UTF-8 is replaced
/// with U+FFFD. A `%` that is not followed by two ASCII hex digits is kept literally, and the
/// characters after it are decoded normally.
fn urlencoding_decode(s: &str) -> String {
    /// Value of one ASCII hex digit; `None` for anything else (notably `+`/`-`, which
    /// `u8::from_str_radix` would accept as a sign).
    fn hex_digit(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while let Some(&b) = bytes.get(i) {
        match b {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_digit);
                let lo = bytes.get(i + 2).copied().and_then(hex_digit);
                if let Some((hi, lo)) = hi.zip(lo) {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
                // Malformed escape: keep the `%` and decode what follows normally.
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
173
174fn create_success_response() -> String {
176 let body = include_str!("oauth_success.html");
177
178 format!(
179 "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
180 body.len(),
181 body
182 )
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    #[test]
    fn parse_callback_from_valid_request() {
        let request = "GET /oauth2callback?code=4%2F0AYWS-abc123&state=verifier HTTP/1.1\r\n";
        let callback = parse_callback_from_request(request).unwrap();
        assert_eq!(callback.code, "4/0AYWS-abc123");
        assert_eq!(callback.state, "verifier");
    }

    #[test]
    fn parse_callback_handles_plus_encoding() {
        let request = "GET /oauth2callback?code=hello+world&state=test+state HTTP/1.1\r\n";
        let callback = parse_callback_from_request(request).unwrap();
        assert_eq!(callback.code, "hello world");
        assert_eq!(callback.state, "test state");
    }

    #[test]
    fn parse_callback_returns_error_for_oauth_error() {
        let request =
            "GET /oauth2callback?error=access_denied&error_description=User+denied HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("User denied"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_code() {
        let request = "GET /oauth2callback?state=verifier HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No authorization code"));
    }

    #[test]
    fn parse_callback_returns_error_for_missing_state() {
        let request = "GET /oauth2callback?code=abc123 HTTP/1.1\r\n";
        let result = parse_callback_from_request(request);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No state parameter"));
    }

    /// Port 0 asks the OS for any free port; the custom URI is stored verbatim regardless.
    #[tokio::test]
    async fn with_redirect_uri_keeps_uri_and_binds_loopback() {
        let handler =
            BrowserOAuthHandler::with_redirect_uri("http://localhost:9999/callback", 0).unwrap();
        assert_eq!(handler.redirect_uri(), "http://localhost:9999/callback");
        assert!(handler.listener.local_addr().unwrap().ip().is_loopback());
    }

    #[tokio::test]
    async fn new_derives_redirect_uri_from_bound_port() {
        let handler = BrowserOAuthHandler::new().unwrap();
        let port = handler.listener.local_addr().unwrap().port();
        assert_ne!(port, 0);
        assert_eq!(handler.redirect_uri(), format!("http://127.0.0.1:{port}/oauth2callback"));
    }

    #[tokio::test]
    async fn accept_oauth_callback_parses_request_and_replies_ok() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Plays the browser: sends the redirect request and collects the full response.
        let browser = tokio::spawn(async move {
            let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
            stream
                .write_all(
                    b"GET /oauth2callback?code=abc&state=xyz HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n",
                )
                .await
                .unwrap();
            let mut response = String::new();
            stream.read_to_string(&mut response).await.unwrap();
            response
        });

        let callback = accept_oauth_callback(&listener).await.unwrap();
        assert_eq!(callback.code, "abc");
        assert_eq!(callback.state, "xyz");

        let response = browser.await.unwrap();
        assert!(response.starts_with("HTTP/1.1 200 OK\r\n"));
    }

    #[test]
    fn urlencoding_decode_handles_percent() {
        assert_eq!(urlencoding_decode("hello%20world"), "hello world");
        assert_eq!(urlencoding_decode("a%2Fb"), "a/b");
    }

    #[test]
    fn urlencoding_decode_handles_plus() {
        assert_eq!(urlencoding_decode("hello+world"), "hello world");
    }
}