tauri 2.11.3

Make tiny, secure apps for all desktop platforms with Tauri
Documentation
// Copyright 2019-2024 Tauri Programme within The Commons Conservancy
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: MIT

use std::{borrow::Cow, sync::Arc};

use http::{header::CONTENT_TYPE, Request, Response as HttpResponse, StatusCode};
use tauri_utils::config::HeaderAddition;

use crate::{
  manager::{webview::PROXY_DEV_SERVER, AppManager},
  webview::{UriSchemeProtocolHandler, WebResourceRequestHandler},
  Runtime,
};

#[cfg(all(dev, mobile))]
use std::{collections::HashMap, sync::Mutex};

#[cfg(all(dev, mobile))]
#[derive(Clone)]
/// A development-server response remembered per proxied URL so it can be replayed when the
/// server later answers `304 Not Modified`.
struct CachedResponse {
  /// Raw response body bytes.
  body: Vec<u8>,
  /// Response headers exactly as received from the development server.
  headers: http::HeaderMap,
  /// Status code returned by the development server.
  status: StatusCode,
}

pub fn get<R: Runtime>(
  manager: Arc<AppManager<R>>,
  window_origin: String,
  web_resource_request_handler: Option<Box<WebResourceRequestHandler>>,
) -> UriSchemeProtocolHandler {
  #[cfg(all(dev, mobile))]
  let (url, client, response_cache) = {
    let use_https = window_origin.starts_with("https");
    let mut url = manager.get_app_url(use_https).as_str().to_string();
    if url.ends_with('/') {
      url.pop();
    }

    #[allow(unused_mut)]
    let mut client_builder = reqwest::ClientBuilder::new();
    if use_https {
      #[cfg(feature = "rustls-tls")]
      if rustls::crypto::CryptoProvider::get_default().is_none() {
        let _ = rustls::crypto::ring::default_provider().install_default();
      }

      // we can't load env vars at runtime, gotta embed them in the lib
      #[allow(unused_variables)]
      if let Some(cert_pem) = option_env!("TAURI_DEV_ROOT_CERTIFICATE") {
        #[cfg(any(
          feature = "native-tls",
          feature = "native-tls-vendored",
          feature = "rustls-tls"
        ))]
        {
          log::info!("adding dev server root certificate");
          let certificate = reqwest::Certificate::from_pem(cert_pem.as_bytes())
            .expect("failed to parse TAURI_DEV_ROOT_CERTIFICATE");
          client_builder = client_builder.tls_certs_merge([certificate]);
        }

        #[cfg(not(any(
          feature = "native-tls",
          feature = "native-tls-vendored",
          feature = "rustls-tls"
        )))]
        {
          log::warn!(
            "the dev root-certificate-path option was provided, but you must enable one of the following Tauri features in Cargo.toml: native-tls, native-tls-vendored, rustls-tls"
          );
        }
      } else {
        log::warn!(
          "loading HTTPS URL; you might need to provide a certificate via the `dev --root-certificate-path` option. You must enable one of the following Tauri features in Cargo.toml: native-tls, native-tls-vendored, rustls-tls"
        );
      }
    }
    let client = client_builder.build().unwrap();

    let response_cache = Mutex::new(HashMap::new());

    (url, client, response_cache)
  };

  let context = Arc::new(Context {
    manager,
    web_resource_request_handler,
    window_origin,
    #[cfg(all(dev, mobile))]
    client,
    #[cfg(all(dev, mobile))]
    url,
    #[cfg(all(dev, mobile))]
    response_cache,
  });

  Box::new(move |_, request, responder| {
    let context = context.clone();
    crate::async_runtime::spawn(async move {
      match get_response(&context, request).await {
        Ok(response) => responder.respond(response),
        Err(e) => responder.respond(
          HttpResponse::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
            .header("Access-Control-Allow-Origin", &context.window_origin)
            .body(e.to_string().into_bytes())
            .unwrap(),
        ),
      }
    });
  })
}

/// State shared by every request served through the protocol handler built in `get`.
struct Context<R: Runtime> {
  /// App manager used to look up assets and the configured security headers.
  manager: Arc<AppManager<R>>,
  /// Value sent in the `Access-Control-Allow-Origin` header of responses.
  window_origin: String,
  /// Optional hook that receives the request and may modify the response before it is sent.
  web_resource_request_handler: Option<Box<WebResourceRequestHandler>>,

  /// Development server base URL, with one trailing `/` removed.
  #[cfg(all(dev, mobile))]
  url: String,
  /// HTTP client used to forward requests to the development server.
  #[cfg(all(dev, mobile))]
  client: reqwest::Client,
  /// Responses received from the development server, keyed by proxied URL; replayed when
  /// the server answers `304 Not Modified`.
  #[cfg(all(dev, mobile))]
  response_cache: Mutex<HashMap<String, CachedResponse>>,
}

/// Resolves a single protocol request into a response.
///
/// 1. The request URI is reduced to a path relative to `tauri://localhost`. Query string and
///    fragment are kept only when `PROXY_DEV_SERVER` is set.
/// 2. The response starts with the app's configured security headers plus
///    `Access-Control-Allow-Origin`.
/// 3. The body comes either from the development server (mobile dev builds) or from the app
///    asset, together with its MIME type and optional CSP header.
/// 4. The optional web resource request handler may then adjust the response.
async fn get_response<R: Runtime>(
  context: &Context<R>,
  request: Request<Vec<u8>>,
) -> Result<HttpResponse<Cow<'static, [u8]>>, Box<dyn std::error::Error>> {
  let uri = request.uri().to_string();

  // A proxied request forwards the whole URI; an asset lookup only needs the part before the
  // first `?` or `#`.
  let full_path: String = if PROXY_DEV_SERVER {
    uri
  } else {
    uri.split(&['?', '#']).next().unwrap().to_owned()
  };

  // `strip_prefix` only yields `None` for requests made to `https://tauri.$P` on Windows and
  // Android where `$P` is not `localhost/*`; those fall back to an empty path.
  let path = match full_path.strip_prefix("tauri://localhost") {
    Some(relative) => relative.to_owned(),
    None => String::default(),
  };

  // Only mutated on the embedded-asset path below.
  #[allow(unused_mut)]
  let mut builder = HttpResponse::builder()
    .add_configured_headers(context.manager.config.app.security.headers.as_ref())
    .header("Access-Control-Allow-Origin", &context.window_origin);

  #[cfg(all(dev, mobile))]
  let mut response = proxy_dev_request(
    &context.client,
    &context.url,
    &context.response_cache,
    path,
    builder,
    &request,
  )
  .await?;

  #[cfg(not(all(dev, mobile)))]
  let mut response = {
    let use_https = request.uri().scheme() == Some(&http::uri::Scheme::HTTPS);
    let asset = context.manager.get_asset(path, use_https)?;
    builder = builder.header(CONTENT_TYPE, &asset.mime_type);
    if let Some(csp) = asset.csp_header.as_ref() {
      builder = builder.header("Content-Security-Policy", csp);
    }
    builder.body(asset.bytes.into())?
  };

  if let Some(handler) = &context.web_resource_request_handler {
    handler(request, &mut response);
  }

  Ok(response)
}

/// Proxies a webview request to the development server and converts the reply into a protocol
/// response.
///
/// `url` is the development server base URL and `path` the request path, which may include a
/// query string. The path is percent-decoded before being joined to `url`. The method, headers
/// and body of `request` are forwarded unchanged. The headers returned by the server are
/// appended to `builder`, which already carries the caller's headers.
///
/// Non-`304` responses are cached by target URL. When the server later answers
/// `304 Not Modified` for a cached URL, the cached status, headers and body are returned.
///
/// # Errors
///
/// Returns an error if the server cannot be reached, the body cannot be read, or the response
/// cannot be built.
///
/// # Panics
///
/// Panics if the response cache mutex is poisoned.
#[cfg(all(dev, mobile))]
async fn proxy_dev_request(
  client: &reqwest::Client,
  url: &str,
  response_cache: &Mutex<HashMap<String, CachedResponse>>,
  path: String,
  mut builder: http::response::Builder,
  request: &Request<Vec<u8>>,
) -> Result<HttpResponse<Cow<'static, [u8]>>, Box<dyn std::error::Error>> {
  // NOTE(review): decoding before the URL is re-parsed turns encoded delimiters (`%3F`, `%23`,
  // `%26`, `%2F`) into real ones, which can change the meaning of the forwarded URL — confirm
  // this decoding is still required.
  let decoded_path = percent_encoding::percent_decode(path.as_bytes())
    .decode_utf8_lossy()
    .to_string();
  let url = format!(
    "{}/{}",
    url.trim_end_matches('/'),
    decoded_path.trim_start_matches('/')
  );

  let mut proxy_builder = client.request(request.method().clone(), &url);
  for (name, value) in request.headers() {
    proxy_builder = proxy_builder.header(name, value);
  }
  proxy_builder = proxy_builder.body(request.body().clone());

  let response = proxy_builder.send().await.map_err(|e| {
    let hint = if let Some(status) = e.status() {
      format!(", status code: {}", status.as_u16())
    } else if cfg!(target_os = "ios") {
      ", did you grant local network permissions? That is required to reach the development server. Please grant the permission via the prompt or in `Settings > Privacy & Security > Local Network` and restart the app. See https://support.apple.com/en-us/102229 for more information.".to_string()
    } else {
      String::new()
    };
    let error_message = format!("Failed to request {url}: {e}{hint}");
    log::error!("{error_message}");
    error_message
  })?;

  let status = response.status();

  if status == http::StatusCode::NOT_MODIFIED {
    // Copy the entry out so the lock is released before the response is assembled.
    let cached = response_cache.lock().unwrap().get(&url).cloned();
    if let Some(cached) = cached {
      for (name, value) in &cached.headers {
        builder = builder.header(name, value);
      }
      return Ok(builder.status(cached.status).body(cached.body.into())?);
    }
  }

  let headers = response.headers().clone();
  let body = response.bytes().await?.to_vec();
  let fresh = CachedResponse {
    status,
    headers,
    body,
  };

  // A bodiless 304 is not worth remembering: it could never be used to answer a later 304.
  if status != http::StatusCode::NOT_MODIFIED {
    response_cache.lock().unwrap().insert(url, fresh.clone());
  }

  for (name, value) in &fresh.headers {
    builder = builder.header(name, value);
  }

  builder
    .status(fresh.status)
    .body(fresh.body.into())
    .map_err(Into::into)
}