use std::sync::Arc;
use axum::extract::{Path, State};
use axum::Json;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tauri::Runtime;
use crate::server::response::{WebDriverErrorResponse, WebDriverResponse, WebDriverResult};
use crate::server::AppState;
use crate::webdriver::Timeouts;
/// Polls the app's window list every 100 ms until a usable window exists.
///
/// With `target_label`, waits for a window carrying exactly that label; without
/// it, takes the first label returned by `get_window_labels`. At least one check
/// is always made (even with `timeout_ms == 0`); once `timeout_ms` has elapsed
/// without a match, a `no_such_window` error is returned.
async fn wait_for_window<R: Runtime>(
    state: &AppState<R>,
    timeout_ms: u64,
    target_label: Option<&str>,
) -> Result<String, WebDriverErrorResponse> {
    let started_at = std::time::Instant::now();
    let give_up_after = std::time::Duration::from_millis(timeout_ms);
    let poll_every = std::time::Duration::from_millis(100);

    loop {
        let labels = state.get_window_labels();
        let found = match target_label {
            Some(wanted) => labels
                .contains(&wanted.to_string())
                .then(|| wanted.to_string()),
            None => labels.first().cloned(),
        };
        if let Some(label) = found {
            return Ok(label);
        }

        if started_at.elapsed() >= give_up_after {
            return Err(WebDriverErrorResponse::no_such_window());
        }
        tokio::time::sleep(poll_every).await;
    }
}
fn extract_window_label(capabilities: &serde_json::Value) -> Option<String> {
if let Some(window_label) = capabilities
.get("alwaysMatch")
.and_then(|v| v.get("wdio:tauriServiceOptions"))
.and_then(|v| v.get("windowLabel"))
.and_then(|v| v.as_str())
{
return Some(window_label.to_string());
}
if let Some(arr) = capabilities.get("firstMatch").and_then(|v| v.as_array()) {
for item in arr {
if let Some(window_label) = item
.get("wdio:tauriServiceOptions")
.and_then(|v| v.get("windowLabel"))
.and_then(|v| v.as_str())
{
return Some(window_label.to_string());
}
}
}
None
}
/// JSON body accepted by the new-session handler [`create`].
#[derive(Debug, Deserialize)]
pub struct CreateSessionRequest {
    /// Raw capabilities object sent by the client. `create` inspects its
    /// `alwaysMatch` / `firstMatch` entries for
    /// `wdio:tauriServiceOptions.windowLabel` to choose the initial window.
    pub capabilities: Value,
}
/// Payload returned from a successful new-session request.
///
/// Field names are serialized in camelCase, so `session_id` goes out as
/// `sessionId`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct SessionResponse {
    /// Identifier assigned to the newly created session.
    pub session_id: String,
    /// Capabilities reported back to the client for this session.
    pub capabilities: Value,
}
/// Maps a webview's `navigator.userAgent` string to `(browserName, browserVersion)`.
///
/// Branches are tried in order, and the order matters:
/// - contains `Edg/` → `"msedge"`, version taken after `Edg/`;
/// - contains `Android` → `"chrome"`, version taken after `Chrome/` (checked
///   before the Linux branch because Android strings typically also mention `Linux`);
/// - contains `Linux` or `X11` → `"WebKitGTK"`, version taken after `AppleWebKit/`;
/// - iPhone/iPad/iPod or Macintosh together with `AppleWebKit/` → `"webkit"`,
///   version taken after `AppleWebKit/`;
/// - anything else → `("webview", "unknown")`.
///
/// Every branch extracts its version the same way (see [`ua_version_after`]);
/// when no usable token follows the marker, the version is `"unknown"`.
fn parse_user_agent(user_agent: &str) -> (String, String) {
    let (name, marker) = if user_agent.contains("Edg/") {
        ("msedge", "Edg/")
    } else if user_agent.contains("Android") {
        ("chrome", "Chrome/")
    } else if user_agent.contains("Linux") || user_agent.contains("X11") {
        ("WebKitGTK", "AppleWebKit/")
    } else if (["iPhone", "iPad", "iPod", "Macintosh"]
        .iter()
        .any(|device| user_agent.contains(device)))
        && user_agent.contains("AppleWebKit/")
    {
        ("webkit", "AppleWebKit/")
    } else {
        return ("webview".to_string(), "unknown".to_string());
    };

    let version = ua_version_after(user_agent, marker).unwrap_or("unknown");
    (name.to_string(), version.to_string())
}

/// Returns the version token directly after the first `marker` in `user_agent`.
///
/// The token ends at the first whitespace and is additionally cut at the first
/// `(` (guards against a missing space before `(KHTML, ...)`). Returns `None`
/// when the marker is absent or nothing non-empty remains.
fn ua_version_after<'a>(user_agent: &'a str, marker: &str) -> Option<&'a str> {
    user_agent
        .split(marker)
        .nth(1)?
        .split_whitespace()
        .next()?
        .split('(')
        .next()
        .filter(|version| !version.is_empty())
}
/// Handles a new-session request.
///
/// The initial window is the `wdio:tauriServiceOptions.windowLabel` named in the
/// capabilities when present, otherwise the first window the app reports; the
/// handler waits up to 10 seconds for it to appear. That window's
/// `navigator.userAgent` is evaluated to fill `browserName` / `browserVersion`;
/// if the script fails, `"webview"` / `"unknown"` are reported instead.
///
/// # Errors
///
/// Propagates the error from [`wait_for_window`] when no suitable window shows
/// up in time, and any error returned by `get_executor_for_window`.
pub async fn create<R: Runtime + 'static>(
    State(state): State<Arc<AppState<R>>>,
    Json(request): Json<CreateSessionRequest>,
) -> WebDriverResult {
    let requested_label = extract_window_label(&request.capabilities);
    let window_label = wait_for_window(&state, 10_000, requested_label.as_deref()).await?;

    let executor =
        state.get_executor_for_window(&window_label, Timeouts::default(), Vec::new())?;
    // Best effort: a failed script still yields a session, just with generic browser info.
    let (browser_name, browser_version) = executor
        .evaluate_js("(function() { return navigator.userAgent; })()")
        .await
        .map(|result| {
            parse_user_agent(result.get("value").and_then(|v| v.as_str()).unwrap_or(""))
        })
        .unwrap_or_else(|_| ("webview".to_string(), "unknown".to_string()));

    // `setWindowRect` is advertised only on desktop builds.
    #[cfg(desktop)]
    let set_window_rect = true;
    #[cfg(mobile)]
    let set_window_rect = false;

    let mut sessions = state.sessions.write().await;
    let session = sessions.create(window_label);

    let capabilities = json!({
        "browserName": browser_name,
        "browserVersion": browser_version,
        "platformName": std::env::consts::OS,
        "acceptInsecureCerts": false,
        "pageLoadStrategy": "normal",
        "setWindowRect": set_window_rect,
        "timeouts": {
            "implicit": session.timeouts.implicit_ms,
            "pageLoad": session.timeouts.page_load_ms,
            "script": session.timeouts.script_ms
        }
    });

    Ok(WebDriverResponse::success(SessionResponse {
        session_id: session.id.clone(),
        capabilities,
    }))
}
/// Handles a delete-session request.
///
/// # Errors
///
/// Returns an `invalid_session_id` error when `session_id` does not name a
/// session known to the session store.
pub async fn delete<R: Runtime>(
    State(state): State<Arc<AppState<R>>>,
    Path(session_id): Path<String>,
) -> WebDriverResult {
    let removed = state.sessions.write().await.delete(&session_id);
    if !removed {
        return Err(WebDriverErrorResponse::invalid_session_id(&session_id));
    }
    Ok(WebDriverResponse::null())
}