robomotion 0.1.3

Official Rust SDK for building Robomotion RPA packages
Documentation
//! OAuth2 dialog support with browser-based authorization flow.

use crate::runtime::{Result, RobomotionError};
use std::net::TcpListener;
use std::time::Duration;
use tokio::sync::mpsc;

// --- Redirect / callback wiring ---------------------------------------------------------------

/// Redirect URI sent to the provider.
///
/// It is fixed because it must match the redirect URI configured in the OAuth app. Its port
/// must stay in sync with [`OAUTH2_CALLBACK_PORT`].
pub const OAUTH2_REDIRECT_URL: &str = "http://localhost:9876/oauth2/callback";

/// Local TCP port the callback listener binds to (the port inside [`OAUTH2_REDIRECT_URL`]).
pub const OAUTH2_CALLBACK_PORT: u16 = 9_876;

// --- Timing ------------------------------------------------------------------------------------

/// Pause between attempts to bind the callback port while it is busy (5 seconds).
pub const OAUTH2_PORT_RETRY_INTERVAL: Duration = Duration::from_secs(5);

/// How long to keep retrying the callback port before giving up (5 minutes).
pub const OAUTH2_PORT_MAX_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// How long to wait for the user to finish authorizing in the browser (5 minutes).
pub const OAUTH2_AUTH_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// OAuth2 configuration for a provider.
///
/// `Debug` is implemented by hand so that `client_secret` is redacted and never ends up in
/// logs or error output.
#[derive(Clone)]
pub struct OAuth2Config {
    /// Public client identifier; sent URL-encoded in the authorization URL.
    pub client_id: String,
    /// Client secret. It is not part of the authorization URL built by this module; it is
    /// presumably used by the caller for the code-for-token exchange.
    pub client_secret: String,
    /// Authorization endpoint; the authorization query parameters are appended to it.
    pub auth_url: String,
    /// Token endpoint for exchanging the authorization code (not read by this module).
    pub token_url: String,
    /// Requested scopes; joined with single spaces in the authorization URL.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuth2Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"<redacted>")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("scopes", &self.scopes)
            .finish()
    }
}

/// Open an OAuth2 dialog and wait for the callback.
///
/// This function:
/// 1. Starts a local HTTP server on port 9876
/// 2. Opens the browser to the authorization URL
/// 3. Waits for the OAuth callback with the authorization code
/// 4. Returns the authorization code
///
/// # Example
/// ```ignore
/// let config = OAuth2Config {
///     client_id: "your-client-id".to_string(),
///     client_secret: "your-client-secret".to_string(),
///     auth_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
///     token_url: "https://oauth2.googleapis.com/token".to_string(),
///     scopes: vec!["email".to_string(), "profile".to_string()],
/// };
///
/// let code = open_oauth_dialog(&config).await?;
/// // Exchange code for token...
/// ```
pub async fn open_oauth_dialog(config: &OAuth2Config) -> Result<String> {
    // Try to bind to the port with retry logic
    let listener = acquire_oauth_port().await?;

    let (tx, mut rx) = mpsc::channel::<std::result::Result<String, String>>(1);

    // Build the authorization URL
    let auth_url = format!(
        "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&access_type=offline&prompt=consent&state=state",
        config.auth_url,
        urlencoding::encode(&config.client_id),
        urlencoding::encode(OAUTH2_REDIRECT_URL),
        urlencoding::encode(&config.scopes.join(" "))
    );

    // Open browser
    if let Err(e) = webbrowser::open(&auth_url) {
        return Err(RobomotionError::OAuth(format!(
            "Failed to open browser: {}",
            e
        )));
    }

    // Start the callback server
    let tx_clone = tx.clone();
    let server_handle = tokio::spawn(async move {
        let listener = tokio::net::TcpListener::from_std(listener).unwrap();

        loop {
            match listener.accept().await {
                Ok((mut stream, _)) => {
                    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

                    let mut reader = BufReader::new(&mut stream);
                    let mut request_line = String::new();

                    if reader.read_line(&mut request_line).await.is_ok() {
                        // Parse the request to get the code
                        if let Some(code) = extract_code_from_request(&request_line) {
                            // Send success response
                            let response = format!(
                                "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n{}",
                                SUCCESS_HTML
                            );
                            let _ = stream.write_all(response.as_bytes()).await;
                            let _ = tx_clone.send(Ok(code)).await;
                            break;
                        } else if let Some(error) = extract_error_from_request(&request_line) {
                            // Send error response
                            let response = format!(
                                "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n{}",
                                ERROR_HTML
                            );
                            let _ = stream.write_all(response.as_bytes()).await;
                            let _ = tx_clone.send(Err(error)).await;
                            break;
                        }
                    }
                }
                Err(_) => break,
            }
        }
    });

    // Wait for the callback with timeout
    let result = tokio::select! {
        result = rx.recv() => {
            match result {
                Some(Ok(code)) => Ok(code),
                Some(Err(e)) => Err(RobomotionError::OAuth(e)),
                None => Err(RobomotionError::OAuth("Channel closed".to_string())),
            }
        }
        _ = tokio::time::sleep(OAUTH2_AUTH_TIMEOUT) => {
            Err(RobomotionError::OAuth(format!(
                "OAuth authorization timed out after {:?} waiting for user to complete authorization",
                OAUTH2_AUTH_TIMEOUT
            )))
        }
    };

    // Clean up
    server_handle.abort();

    result
}

/// Bind the OAuth callback listener on `127.0.0.1`, waiting for the port if it is busy.
///
/// The port cannot simply be changed on conflict, because it is part of
/// [`OAUTH2_REDIRECT_URL`]. Binding is retried every [`OAUTH2_PORT_RETRY_INTERVAL`] until
/// [`OAUTH2_PORT_MAX_TIMEOUT`] has elapsed.
///
/// The returned listener is switched to non-blocking mode, which
/// `tokio::net::TcpListener::from_std` requires.
///
/// # Errors
///
/// Returns `RobomotionError::OAuth` in either of these cases:
/// - the port still cannot be bound after the deadline (the message includes the last bind
///   error);
/// - non-blocking mode cannot be enabled.
async fn acquire_oauth_port() -> Result<std::net::TcpListener> {
    let addr = format!("127.0.0.1:{}", OAUTH2_CALLBACK_PORT);
    let start_time = std::time::Instant::now();
    let mut attempt: u32 = 0;

    loop {
        attempt += 1;

        let bind_error = match TcpListener::bind(addr.as_str()) {
            Ok(listener) => {
                // A blocking socket handed to tokio would block the runtime thread on I/O,
                // so a failure here must be reported rather than ignored.
                listener.set_nonblocking(true).map_err(|e| {
                    RobomotionError::OAuth(format!(
                        "Failed to put OAuth callback listener on port {} into non-blocking mode: {}",
                        OAUTH2_CALLBACK_PORT, e
                    ))
                })?;
                if attempt > 1 {
                    tracing::info!(
                        "OAuth: Successfully acquired port {} after {} attempts (elapsed: {:?})",
                        OAUTH2_CALLBACK_PORT,
                        attempt,
                        start_time.elapsed()
                    );
                }
                return Ok(listener);
            }
            Err(e) => e,
        };

        if start_time.elapsed() >= OAUTH2_PORT_MAX_TIMEOUT {
            return Err(RobomotionError::OAuth(format!(
                "Could not bind to port {} after {:?} ({} attempts). \
                 Port may be in use by another process (last error: {}).",
                OAUTH2_CALLBACK_PORT, OAUTH2_PORT_MAX_TIMEOUT, attempt, bind_error
            )));
        }

        // Log only once so a long wait does not flood the log.
        if attempt == 1 {
            tracing::info!(
                "OAuth: Port {} is busy ({}), waiting for it to become available...",
                OAUTH2_CALLBACK_PORT,
                bind_error
            );
        }

        tokio::time::sleep(OAUTH2_PORT_RETRY_INTERVAL).await;
    }
}

/// Extract the authorization `code` query parameter from an HTTP request line.
///
/// `request` is the first line of an HTTP request, e.g.
/// `GET /oauth2/callback?code=abc&state=xyz HTTP/1.1`; a trailing CRLF is tolerated.
///
/// Returns `None` in any of these cases:
/// - the request path is not the OAuth callback path;
/// - there is no `code=` parameter;
/// - the first `code` value does not percent-decode to valid UTF-8.
fn extract_code_from_request(request: &str) -> Option<String> {
    // The request target is the second whitespace-separated token of the request line.
    let target = request.split_whitespace().nth(1)?;
    let (path, query) = target.split_once('?')?;

    // Check the path itself. A substring search over the whole line would also accept
    // requests whose query string merely contains "/oauth2/callback".
    if !path.ends_with("/oauth2/callback") {
        return None;
    }

    let (_, raw_code) = query
        .split('&')
        .filter_map(|param| param.split_once('='))
        .find(|(key, _)| *key == "code")?;
    urlencoding::decode(raw_code).ok().map(|code| code.into_owned())
}

/// Extract the error from the request.
fn extract_error_from_request(request: &str) -> Option<String> {
    if request.contains("/oauth2/callback") {
        if let Some(query_start) = request.find('?') {
            let query = &request[query_start + 1..];
            if let Some(end) = query.find(' ') {
                let query = &query[..end];
                let mut error = None;
                let mut error_desc = None;

                for param in query.split('&') {
                    let parts: Vec<&str> = param.splitn(2, '=').collect();
                    if parts.len() == 2 {
                        match parts[0] {
                            "error" => error = urlencoding::decode(parts[1]).ok().map(|s| s.into_owned()),
                            "error_description" => {
                                error_desc = urlencoding::decode(parts[1]).ok().map(|s| s.into_owned())
                            }
                            _ => {}
                        }
                    }
                }

                if let Some(err) = error {
                    return Some(match error_desc {
                        Some(desc) => format!("{}: {}", err, desc),
                        None => err,
                    });
                }
            }
        }
    }
    None
}

/// HTML page sent to the browser once the OAuth callback delivered an authorization code.
///
/// It tells the user they can close the window and return to the application. It is static
/// markup with inline CSS only: no scripts and no external resources.
const SUCCESS_HTML: &str = r#"<!DOCTYPE html>
<html>
<head>
    <title>Authorization Successful</title>
    <style>
        body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #f5f5f5; }
        .container { text-align: center; padding: 40px; background: white; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }
        h1 { color: #4CAF50; margin-bottom: 16px; }
        p { color: #666; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Authorization Successful</h1>
        <p>You can close this window and return to the application.</p>
    </div>
</body>
</html>"#;

/// HTML page shown in the browser when authorization fails, e.g. when the OAuth callback
/// carries an `error` parameter instead of a code.
///
/// It is intentionally generic: the provider's error text is not embedded in the page, so
/// no escaping is needed. It is static markup with inline CSS only.
const ERROR_HTML: &str = r#"<!DOCTYPE html>
<html>
<head>
    <title>Authorization Failed</title>
    <style>
        body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #f5f5f5; }
        .container { text-align: center; padding: 40px; background: white; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }
        h1 { color: #f44336; margin-bottom: 16px; }
        p { color: #666; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Authorization Failed</h1>
        <p>An error occurred during authorization.</p>
        <p>You can close this window and try again.</p>
    </div>
</body>
</html>"#;