use std::{borrow::Cow, sync::Arc};
use http::{header::CONTENT_TYPE, Request, Response as HttpResponse, StatusCode};
use tauri_utils::config::HeaderAddition;
use crate::{
manager::{webview::PROXY_DEV_SERVER, AppManager},
webview::{UriSchemeProtocolHandler, WebResourceRequestHandler},
Runtime,
};
#[cfg(all(dev, mobile))]
use std::{collections::HashMap, sync::Mutex};
/// Snapshot of a response received from the development server.
///
/// Kept per upstream URL so it can be replayed when the dev server later
/// answers `304 Not Modified` for the same URL.
#[cfg(all(dev, mobile))]
#[derive(Clone)]
struct CachedResponse {
  /// Status code returned by the dev server.
  status: http::StatusCode,
  /// Response headers exactly as received from the dev server.
  headers: http::HeaderMap,
  /// Fully buffered response body.
  body: Vec<u8>,
}
/// Creates the URI scheme protocol handler that serves the app's frontend.
///
/// Each request is answered asynchronously on Tauri's async runtime. If building the
/// response fails, a `500 Internal Server Error` is sent instead. That response has a
/// plain-text body holding the error message and still carries
/// `Access-Control-Allow-Origin: <window_origin>`.
///
/// In mobile development builds (`dev` + `mobile`), requests are proxied to the URL
/// returned by `manager.get_app_url` instead of being served from the app's assets.
/// `use_https` is true when `window_origin` starts with `https`. It is passed to
/// `get_app_url` and enables the TLS setup of the proxy client.
///
/// # Panics
///
/// In mobile development builds, panics if the compile-time `TAURI_DEV_ROOT_CERTIFICATE`
/// cannot be parsed or the proxy HTTP client cannot be built.
pub fn get<R: Runtime>(
  manager: Arc<AppManager<R>>,
  window_origin: String,
  web_resource_request_handler: Option<Box<WebResourceRequestHandler>>,
) -> UriSchemeProtocolHandler {
  #[cfg(all(dev, mobile))]
  let (url, client, response_cache) = {
    let use_https = window_origin.starts_with("https");
    let mut url = manager.get_app_url(use_https).as_str().to_string();
    // Drop a trailing slash so request paths can be appended after a single `/`.
    if url.ends_with('/') {
      url.pop();
    }
    let client = build_dev_proxy_client(use_https);
    (url, client, Mutex::new(HashMap::new()))
  };

  let context = Arc::new(Context {
    manager,
    web_resource_request_handler,
    window_origin,
    #[cfg(all(dev, mobile))]
    client,
    #[cfg(all(dev, mobile))]
    url,
    #[cfg(all(dev, mobile))]
    response_cache,
  });

  Box::new(move |_, request, responder| {
    let context = context.clone();
    crate::async_runtime::spawn(async move {
      match get_response(&context, request).await {
        Ok(response) => responder.respond(response),
        Err(e) => responder.respond(
          HttpResponse::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
            .header("Access-Control-Allow-Origin", &context.window_origin)
            .body(e.to_string().into_bytes())
            .unwrap(),
        ),
      }
    });
  })
}

/// Builds the HTTP client used to proxy webview requests to the development server.
///
/// When `use_https` is set, this does two things:
/// - With the `rustls-tls` feature, it installs the `ring` rustls crypto provider if none
///   is set yet.
/// - It trusts the root certificate embedded at compile time via
///   `TAURI_DEV_ROOT_CERTIFICATE`, provided one was given and a TLS feature is enabled.
///
/// Otherwise it only logs a warning.
///
/// # Panics
///
/// Panics if the embedded certificate cannot be parsed or the client cannot be built.
#[cfg(all(dev, mobile))]
fn build_dev_proxy_client(use_https: bool) -> reqwest::Client {
  // Only reassigned when one of the TLS features is enabled.
  #[allow(unused_mut)]
  let mut client_builder = reqwest::ClientBuilder::new();
  if use_https {
    #[cfg(feature = "rustls-tls")]
    if rustls::crypto::CryptoProvider::get_default().is_none() {
      let _ = rustls::crypto::ring::default_provider().install_default();
    }
    // `cert_pem` is unused when no TLS feature is enabled.
    #[allow(unused_variables)]
    if let Some(cert_pem) = option_env!("TAURI_DEV_ROOT_CERTIFICATE") {
      #[cfg(any(
        feature = "native-tls",
        feature = "native-tls-vendored",
        feature = "rustls-tls"
      ))]
      {
        log::info!("adding dev server root certificate");
        let certificate = reqwest::Certificate::from_pem(cert_pem.as_bytes())
          .expect("failed to parse TAURI_DEV_ROOT_CERTIFICATE");
        client_builder = client_builder.tls_certs_merge([certificate]);
      }
      #[cfg(not(any(
        feature = "native-tls",
        feature = "native-tls-vendored",
        feature = "rustls-tls"
      )))]
      {
        log::warn!(
          "the dev root-certificate-path option was provided, but you must enable one of the following Tauri features in Cargo.toml: native-tls, native-tls-vendored, rustls-tls"
        );
      }
    } else {
      log::warn!(
        "loading HTTPS URL; you might need to provide a certificate via the `dev --root-certificate-path` option. You must enable one of the following Tauri features in Cargo.toml: native-tls, native-tls-vendored, rustls-tls"
      );
    }
  }
  client_builder
    .build()
    .expect("failed to build the HTTP client used to proxy the development server")
}
/// Shared state captured by the protocol handler closure and read by every request.
struct Context<R: Runtime> {
  /// App manager used to resolve assets and read the configured security headers.
  manager: Arc<AppManager<R>>,
  /// Origin echoed back in the `Access-Control-Allow-Origin` response header.
  window_origin: String,
  /// Optional user hook that receives the request and may modify the response.
  web_resource_request_handler: Option<Box<WebResourceRequestHandler>>,
  /// HTTP client used to forward requests to the development server.
  #[cfg(all(dev, mobile))]
  client: reqwest::Client,
  /// Development server base URL, with one trailing `/` removed.
  #[cfg(all(dev, mobile))]
  url: String,
  /// Last response per upstream URL, replayed on `304 Not Modified`.
  #[cfg(all(dev, mobile))]
  response_cache: Mutex<HashMap<String, CachedResponse>>,
}
/// Builds the response for a single protocol request.
///
/// Path resolution:
/// - Unless the dev server proxy is active, the query string and fragment are dropped.
/// - The `tauri://localhost` prefix is then stripped from the request URI; a URI without
///   that prefix resolves to an empty path.
///
/// Every response starts with the configured security headers and
/// `Access-Control-Allow-Origin: <window_origin>`. In mobile dev builds the request is then
/// proxied to the dev server. Otherwise the asset is loaded through `manager.get_asset`,
/// along with its MIME type and optional CSP header. The web resource request handler, if
/// set, runs last and may modify the response.
///
/// # Errors
///
/// Propagates asset lookup, proxy, and response-building errors.
async fn get_response<R: Runtime>(
  context: &Context<R>,
  request: Request<Vec<u8>>,
) -> Result<HttpResponse<Cow<'static, [u8]>>, Box<dyn std::error::Error>> {
  let Context {
    manager,
    window_origin,
    web_resource_request_handler,
    #[cfg(all(dev, mobile))]
    url,
    #[cfg(all(dev, mobile))]
    client,
    #[cfg(all(dev, mobile))]
    response_cache,
  } = context;

  let full_uri = request.uri().to_string();
  // The proxy forwards the query string and fragment; asset lookup ignores them.
  let relevant_uri = if PROXY_DEV_SERVER {
    full_uri.as_str()
  } else {
    full_uri.split(&['?', '#']).next().unwrap_or_default()
  };
  let path = relevant_uri
    .strip_prefix("tauri://localhost")
    .unwrap_or_default()
    .to_string();

  let builder = HttpResponse::builder()
    .add_configured_headers(manager.config.app.security.headers.as_ref())
    .header("Access-Control-Allow-Origin", window_origin);

  #[cfg(all(dev, mobile))]
  let mut response =
    proxy_dev_request(client, url, response_cache, path, builder, &request).await?;

  #[cfg(not(all(dev, mobile)))]
  let mut response = {
    let is_https = request.uri().scheme() == Some(&http::uri::Scheme::HTTPS);
    let asset = manager.get_asset(path, is_https)?;
    let builder = builder.header(CONTENT_TYPE, &asset.mime_type);
    let builder = match &asset.csp_header {
      Some(csp) => builder.header("Content-Security-Policy", csp),
      None => builder,
    };
    builder.body(asset.bytes.into())?
  };

  if let Some(on_request) = web_resource_request_handler {
    on_request(request, &mut response);
  }
  Ok(response)
}
/// Forwards `request` to the development server and converts its reply into a webview response.
///
/// `path` is percent-decoded and appended to `url`. The request method, headers and body are
/// forwarded unchanged. `builder` carries headers that stay on the returned response; the
/// upstream headers are appended to it.
///
/// Full (non-304) responses are stored in `response_cache`, keyed by the upstream URL. When
/// the dev server later answers `304 Not Modified` for a URL with a cached entry, the cached
/// status, headers and body are returned instead.
///
/// # Errors
///
/// Returns an error if the dev server cannot be reached, the body cannot be read, or the
/// response cannot be assembled. Request failures are also logged.
#[cfg(all(dev, mobile))]
async fn proxy_dev_request(
  client: &reqwest::Client,
  url: &str,
  response_cache: &Mutex<HashMap<String, CachedResponse>>,
  path: String,
  mut builder: http::response::Builder,
  request: &Request<Vec<u8>>,
) -> Result<HttpResponse<Cow<'static, [u8]>>, Box<dyn std::error::Error>> {
  let decoded_path = percent_encoding::percent_decode(path.as_bytes())
    .decode_utf8_lossy()
    .to_string();
  let url = format!(
    "{}/{}",
    url.trim_end_matches('/'),
    decoded_path.trim_start_matches('/')
  );

  let mut proxy_builder = client.request(request.method().clone(), &url);
  for (name, value) in request.headers() {
    proxy_builder = proxy_builder.header(name, value);
  }
  proxy_builder = proxy_builder.body(request.body().clone());

  let response = proxy_builder.send().await.map_err(|e| {
    // Optional extra context appended after the reqwest error text.
    let hint = if let Some(status) = e.status() {
      format!(", status code: {}", status.as_u16())
    } else if cfg!(target_os = "ios") {
      ", did you grant local network permissions? That is required to reach the development server. Please grant the permission via the prompt or in `Settings > Privacy & Security > Local Network` and restart the app. See https://support.apple.com/en-us/102229 for more information.".to_string()
    } else {
      String::new()
    };
    let error_message = format!("Failed to request {url}: {e}{hint}");
    log::error!("{error_message}");
    error_message
  })?;

  let status = response.status();
  if status == http::StatusCode::NOT_MODIFIED {
    // Release the cache lock before building the response.
    let cached = response_cache.lock().unwrap().get(&url).cloned();
    if let Some(cached) = cached {
      for (name, value) in &cached.headers {
        builder = builder.header(name, value);
      }
      return Ok(builder.status(cached.status).body(cached.body.into())?);
    }
  }

  let headers = response.headers().clone();
  let body = response.bytes().await?.to_vec();
  let response = CachedResponse {
    status,
    headers,
    body,
  };

  // A bodiless 304 has nothing worth replaying; only remember full responses.
  if status != http::StatusCode::NOT_MODIFIED {
    response_cache
      .lock()
      .unwrap()
      .insert(url, response.clone());
  }

  for (name, value) in &response.headers {
    builder = builder.header(name, value);
  }
  builder
    .status(response.status)
    .body(response.body.into())
    .map_err(Into::into)
}