Skip to main content

slack_rs/oauth/
server.rs

//! Local callback server for the OAuth flow.
//!
//! Runs a temporary HTTP server on localhost to receive the OAuth callback.
4
5use super::types::OAuthError;
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpListener;
10use tokio::time::{timeout, Duration};
11
/// Values extracted from a successful OAuth callback request.
#[derive(Clone, Debug)]
pub struct CallbackResult {
    /// Value of the callback's `code` query parameter.
    pub code: String,
    /// Value of the callback's `state` query parameter. `run_callback_server`
    /// only returns a `CallbackResult` once this equals the expected state.
    // NOTE(review): presumably no caller reads this field yet, hence the
    // dead-code allowance — confirm before removing either.
    #[allow(dead_code)]
    pub state: String,
}
18
19/// Run a local HTTP server to receive OAuth callback
20///
21/// Returns the authorization code and state received from the callback
22///
23/// # Arguments
24/// * `port` - Port to listen on (typically 3000)
25/// * `expected_state` - Expected state value for CSRF verification
26/// * `timeout_secs` - Timeout in seconds (default 300)
27pub async fn run_callback_server(
28    port: u16,
29    expected_state: String,
30    timeout_secs: u64,
31) -> Result<CallbackResult, OAuthError> {
32    let bind_addr = format!("127.0.0.1:{}", port);
33    let listener = TcpListener::bind(&bind_addr)
34        .await
35        .map_err(|e| OAuthError::ServerError(format!("Failed to bind to port {}: {}", port, e)))?;
36
37    let actual_port = listener.local_addr().map(|a| a.port()).unwrap_or(port);
38    println!(
39        "Listening for OAuth callback on http://127.0.0.1:{}",
40        actual_port
41    );
42
43    let result: Arc<Mutex<Option<Result<CallbackResult, OAuthError>>>> = Arc::new(Mutex::new(None));
44
45    let server_result = result.clone();
46    let server_task = async move {
47        loop {
48            let (mut socket, _) = match listener.accept().await {
49                Ok(conn) => conn,
50                Err(e) => {
51                    let mut res = server_result.lock().unwrap();
52                    *res = Some(Err(OAuthError::ServerError(format!(
53                        "Failed to accept connection: {}",
54                        e
55                    ))));
56                    break;
57                }
58            };
59
60            let mut buffer = vec![0; 4096];
61            let n = match socket.read(&mut buffer).await {
62                Ok(n) if n > 0 => n,
63                _ => continue,
64            };
65
66            let request = String::from_utf8_lossy(&buffer[..n]);
67
68            // Parse the request line
69            if let Some(first_line) = request.lines().next() {
70                if let Some(path_part) = first_line.split_whitespace().nth(1) {
71                    if let Some(query_start) = path_part.find('?') {
72                        let query = &path_part[query_start + 1..];
73                        let params = parse_query_string(query);
74
75                        let response = if let (Some(code), Some(state)) =
76                            (params.get("code"), params.get("state"))
77                        {
78                            // Verify state
79                            if state != &expected_state {
80                                let mut res = server_result.lock().unwrap();
81                                *res = Some(Err(OAuthError::StateMismatch {
82                                    expected: expected_state.clone(),
83                                    actual: state.clone(),
84                                }));
85                                create_error_response("State mismatch - possible CSRF attack")
86                            } else {
87                                let mut res = server_result.lock().unwrap();
88                                *res = Some(Ok(CallbackResult {
89                                    code: code.clone(),
90                                    state: state.clone(),
91                                }));
92                                create_success_response()
93                            }
94                        } else if let Some(error) = params.get("error") {
95                            let mut res = server_result.lock().unwrap();
96                            *res = Some(Err(OAuthError::SlackError(error.clone())));
97                            create_error_response(&format!("OAuth error: {}", error))
98                        } else {
99                            create_error_response("Missing required parameters")
100                        };
101
102                        let _ = socket.write_all(response.as_bytes()).await;
103                        let _ = socket.flush().await;
104                        break;
105                    }
106                }
107            }
108        }
109    };
110
111    // Run with timeout
112    match timeout(Duration::from_secs(timeout_secs), server_task).await {
113        Ok(_) => {
114            let res = result.lock().unwrap();
115            match res.as_ref() {
116                Some(Ok(callback_result)) => Ok(callback_result.clone()),
117                Some(Err(e)) => Err(format_oauth_error(e)),
118                None => Err(OAuthError::ServerError("No result received".to_string())),
119            }
120        }
121        Err(_) => Err(OAuthError::ServerError(format!(
122            "Timeout after {} seconds waiting for callback",
123            timeout_secs
124        ))),
125    }
126}
127
128/// Helper function to format OAuthError for re-creation
129fn format_oauth_error(err: &OAuthError) -> OAuthError {
130    match err {
131        OAuthError::ConfigError(msg) => OAuthError::ConfigError(msg.clone()),
132        OAuthError::NetworkError(msg) => OAuthError::NetworkError(msg.clone()),
133        OAuthError::HttpError(code, msg) => OAuthError::HttpError(*code, msg.clone()),
134        OAuthError::ParseError(msg) => OAuthError::ParseError(msg.clone()),
135        OAuthError::SlackError(msg) => OAuthError::SlackError(msg.clone()),
136        OAuthError::StateMismatch { expected, actual } => OAuthError::StateMismatch {
137            expected: expected.clone(),
138            actual: actual.clone(),
139        },
140        OAuthError::ServerError(msg) => OAuthError::ServerError(msg.clone()),
141        OAuthError::BrowserError(msg) => OAuthError::BrowserError(msg.clone()),
142    }
143}
144
145/// Parse URL query string into a HashMap
146fn parse_query_string(query: &str) -> HashMap<String, String> {
147    query
148        .split('&')
149        .filter_map(|pair| {
150            let mut parts = pair.split('=');
151            match (parts.next(), parts.next()) {
152                (Some(key), Some(value)) => Some((
153                    key.to_string(),
154                    urlencoding::decode(value).ok()?.to_string(),
155                )),
156                _ => None,
157            }
158        })
159        .collect()
160}
161
/// Build the raw HTTP/1.1 `200 OK` response shown in the browser once the
/// callback has been accepted.
///
/// The response declares `Connection: close`, so the end of the page is
/// marked by the connection closing rather than by a `Content-Length`.
fn create_success_response() -> String {
    const HEADERS: &str = concat!(
        "HTTP/1.1 200 OK\r\n",
        "Content-Type: text/html; charset=utf-8\r\n",
        "Connection: close\r\n",
        "\r\n"
    );
    const PAGE: &str = concat!(
        "<html>",
        "<head><title>Authentication Successful</title></head>",
        "<body>",
        "<h1>✓ Authentication Successful</h1>",
        "<p>You can close this window and return to the CLI.</p>",
        "</body>",
        "</html>"
    );
    [HEADERS, PAGE].concat()
}
176
/// Build a raw HTTP/1.1 `400 Bad Request` response whose HTML page shows
/// `message`.
///
/// `message` is HTML-escaped before being embedded. Callers may pass text
/// taken from the request's query string (the OAuth `error` parameter), and
/// inserting it verbatim would let a crafted link inject markup or script
/// into the page served from localhost.
fn create_error_response(message: &str) -> String {
    format!(
        "HTTP/1.1 400 Bad Request\r\n\
         Content-Type: text/html; charset=utf-8\r\n\
         Connection: close\r\n\
         \r\n\
         <html>\
         <head><title>Authentication Failed</title></head>\
         <body>\
         <h1>✗ Authentication Failed</h1>\
         <p>{}</p>\
         </body>\
         </html>",
        escape_html(message)
    )
}

/// Replace the characters that are significant in HTML (`&`, `<`, `>`, `"`,
/// `'`) with their entity references.
fn escape_html(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
193
// Minimal in-file percent-decoding for query strings, so the callback server
// needs no external URL-encoding crate.
mod urlencoding {
    /// Percent-decode `s` as it appears in an `application/x-www-form-urlencoded`
    /// query string: `%XX` becomes the byte `0xXX` and `+` becomes a space.
    ///
    /// The decoded bytes are interpreted as UTF-8, so multi-byte sequences such
    /// as `%C3%A9` decode to `é`. Characters that are not escaped are copied
    /// through unchanged.
    ///
    /// # Errors
    /// Returns `Err(())` if a `%` is not followed by two hex digits, or if the
    /// decoded bytes are not valid UTF-8.
    pub fn decode(s: &str) -> Result<String, ()> {
        let bytes = s.as_bytes();
        let mut out = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            match bytes[i] {
                b'%' => {
                    let hi = bytes.get(i + 1).and_then(|&b| hex_digit(b)).ok_or(())?;
                    let lo = bytes.get(i + 2).and_then(|&b| hex_digit(b)).ok_or(())?;
                    out.push((hi << 4) | lo);
                    i += 3;
                }
                b'+' => {
                    out.push(b' ');
                    i += 1;
                }
                other => {
                    out.push(other);
                    i += 1;
                }
            }
        }
        String::from_utf8(out).map_err(|_| ())
    }

    /// Value of a single ASCII hex digit, or `None` for anything else
    /// (including signs, which `u8::from_str_radix` would accept).
    fn hex_digit(b: u8) -> Option<u8> {
        // `to_digit(16)` yields at most 15, so the cast is lossless.
        (b as char).to_digit(16).map(|d| d as u8)
    }
}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_query_string() {
        let query = "code=test_code&state=test_state&foo=bar";
        let params = parse_query_string(query);

        assert_eq!(params.get("code"), Some(&"test_code".to_string()));
        assert_eq!(params.get("state"), Some(&"test_state".to_string()));
        assert_eq!(params.get("foo"), Some(&"bar".to_string()));
    }

    #[test]
    fn test_parse_query_string_with_encoding() {
        let query = "message=hello+world&name=test%20user";
        let params = parse_query_string(query);

        assert_eq!(params.get("message"), Some(&"hello world".to_string()));
        assert_eq!(params.get("name"), Some(&"test user".to_string()));
    }

    #[test]
    fn test_parse_query_string_skips_pairs_without_equals() {
        let params = parse_query_string("flag&code=abc");

        assert_eq!(params.len(), 1);
        assert_eq!(params.get("code"), Some(&"abc".to_string()));
    }

    #[test]
    fn test_parse_query_string_empty() {
        assert!(parse_query_string("").is_empty());
    }

    #[test]
    fn test_response_status_lines() {
        assert!(create_success_response().starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(create_error_response("x").starts_with("HTTP/1.1 400 Bad Request\r\n"));
    }

    #[tokio::test]
    async fn test_callback_server_timeout() {
        // Port 0 lets the OS pick a free port, so the test cannot collide with
        // another process (or a parallel test run) using a fixed port.
        let state = "test_state".to_string();
        let result = run_callback_server(0, state, 1).await;

        assert!(result.is_err());
        match result {
            Err(OAuthError::ServerError(msg)) => {
                assert!(msg.contains("Timeout"), "unexpected message: {}", msg);
            }
            _ => panic!("Expected ServerError with timeout"),
        }
    }
}