use crate::{servers::HTTP_PORT, MIN_SERVER_VERSION};
use hyper::{
header::{self, HeaderName, HeaderValue},
Body, HeaderMap, Response,
};
use log::error;
use reqwest::{Client, Identity, Upgraded};
use semver::Version;
use serde::{Deserialize, Serialize};
use std::{path::Path, str::FromStr};
use thiserror::Error;
use url::Url;
/// Identifier a Pocket Relay server reports in the `ident` field of its
/// details response; anything else is rejected as "not a Pocket Relay server".
pub const SERVER_IDENT: &str = "POCKET_RELAY_SERVER";

// Server API endpoints. They have no leading `/` because they are joined onto
// the server's base URL with `Url::join`, which keeps any base path prefix.

/// Endpoint returning the server details (version, identifier, association).
pub const DETAILS_ENDPOINT: &str = "api/server";
/// Endpoint that telemetry events are posted to.
pub const TELEMETRY_ENDPOINT: &str = "api/server/telemetry";
/// Endpoint that is upgraded into the "blaze" stream.
pub const UPGRADE_ENDPOINT: &str = "api/server/upgrade";
/// Endpoint that is upgraded into the "tunnel" stream.
pub const TUNNEL_ENDPOINT: &str = "api/server/tunnel";
/// User agent sent by clients built with `create_http_client`, in the form
/// `PocketRelayClient/v<crate version>` (version taken from Cargo at compile time).
pub const USER_AGENT: &str = concat!("PocketRelayClient/v", env!("CARGO_PKG_VERSION"));
/// Names of the custom HTTP headers exchanged with the server.
///
/// Every name must stay lowercase: they are passed to `HeaderName::from_static`,
/// which only accepts lowercase header names.
mod headers {
    /// Carries the association token issued by the server.
    pub const ASSOCIATION: &str = "x-association";

    // Legacy headers sent on the upgrade request.

    /// Legacy header carrying a URL scheme.
    pub const LEGACY_SCHEME: &str = "x-pocket-relay-scheme";
    /// Legacy header carrying a host.
    pub const LEGACY_HOST: &str = "x-pocket-relay-host";
    /// Legacy header carrying a port.
    pub const LEGACY_PORT: &str = "x-pocket-relay-port";
    /// Legacy header flagging local HTTP.
    pub const LEGACY_LOCAL_HTTP: &str = "x-pocket-relay-local-http";
}
/// Builds a reqwest [`Client`] that identifies itself with [`USER_AGENT`].
///
/// When `identity` is provided it is installed on the builder via
/// `ClientBuilder::identity` (client certificate authentication).
///
/// # Errors
///
/// Returns whatever error `ClientBuilder::build` reports.
pub fn create_http_client(identity: Option<Identity>) -> Result<Client, reqwest::Error> {
    let base = Client::builder().user_agent(USER_AGENT);

    let configured = match identity {
        Some(client_identity) => base.identity(client_identity),
        None => base,
    };

    configured.build()
}
/// Errors produced by `read_client_identity`.
#[derive(Debug, Error)]
pub enum ClientIdentityError {
    /// The identity file could not be read from disk.
    #[error("Failed to read identity: {0}")]
    Read(#[from] std::io::Error),
    /// The file contents could not be turned into a reqwest [`Identity`].
    #[error("Failed to create identity: {0}")]
    Create(#[from] reqwest::Error),
}
/// Loads a client [`Identity`] from a DER-encoded PKCS#12 file at `path`.
///
/// The archive is opened with an empty password. The file is read
/// synchronously with `std::fs::read`.
///
/// # Errors
///
/// * [`ClientIdentityError::Read`] if the file cannot be read
/// * [`ClientIdentityError::Create`] if the contents are not a usable PKCS#12 identity
pub fn read_client_identity(path: &Path) -> Result<Identity, ClientIdentityError> {
    // `?` converts through the `#[from]` impls on `ClientIdentityError`.
    let der = std::fs::read(path)?;
    let identity = Identity::from_pkcs12_der(&der, "")?;
    Ok(identity)
}
/// JSON body returned by the server details endpoint and decoded by `lookup_server`.
#[derive(Deserialize)]
struct ServerDetails {
    /// Version the server reports; checked against `MIN_SERVER_VERSION`.
    version: Version,
    /// Server identifier, which must equal `SERVER_IDENT`.
    /// Treated as `None` when the field is absent.
    #[serde(default)]
    ident: Option<String>,
    /// Optional association token supplied by the server.
    association: Option<String>,
}
/// Result of a successful `lookup_server` call.
#[derive(Debug, Clone)]
pub struct LookupData {
    /// Base URL the server was reached at. `lookup_server` appends a trailing `/`
    /// so endpoint paths can be joined onto it.
    pub url: Url,
    /// Version reported by the server (at least `MIN_SERVER_VERSION`).
    pub version: Version,
    /// Association token handed out by the server, if any.
    pub association: Option<String>,
}
/// Errors that can occur while looking up a server with `lookup_server`.
#[derive(Debug, Error)]
pub enum LookupError {
    /// The provided host could not be turned into a valid URL.
    #[error("Invalid Connection URL: {0}")]
    InvalidHostTarget(#[from] url::ParseError),
    /// Sending the details request failed.
    #[error("Failed to connect to server: {0}")]
    ConnectionFailed(reqwest::Error),
    /// The server answered the details request with an error status.
    #[error("Server replied with error response: {0}")]
    ErrorResponse(reqwest::Error),
    /// The details response body could not be decoded.
    #[error("Invalid server response: {0}")]
    InvalidResponse(reqwest::Error),
    /// The identifier in the details response was missing or not `SERVER_IDENT`.
    #[error("Server identifier was incorrect (Not a PocketRelay server?)")]
    NotPocketRelay,
    /// The server version (first field) is below the required minimum (second field).
    #[error("Server version is too outdated ({0}) this client requires servers of version {1} or greater")]
    ServerOutdated(Version, Version),
}
pub async fn lookup_server(
http_client: reqwest::Client,
host: String,
) -> Result<LookupData, LookupError> {
let mut url = String::new();
let mut inferred_scheme = false;
if !host.starts_with("http://") && !host.starts_with("https://") {
url.push_str("http://");
inferred_scheme = true;
}
url.push_str(&host);
if !url.ends_with('/') {
url.push('/');
}
let mut url = Url::from_str(&url)?;
if url.port().is_some_and(|port| port == 443) && inferred_scheme {
let _ = url.set_scheme("https");
}
let info_url = url
.join(DETAILS_ENDPOINT)
.expect("Failed to create server details URL");
let response = http_client
.get(info_url)
.header(header::ACCEPT, "application/json")
.send()
.await
.map_err(LookupError::ConnectionFailed)?;
#[cfg(debug_assertions)]
{
use log::debug;
debug!("Response Status: {}", response.status());
debug!("HTTP Version: {:?}", response.version());
debug!("Content Length: {:?}", response.content_length());
debug!("HTTP Headers: {:?}", response.headers());
}
let response = response
.error_for_status()
.map_err(LookupError::ErrorResponse)?;
let details = response
.json::<ServerDetails>()
.await
.map_err(LookupError::InvalidResponse)?;
if details.ident.is_none() || details.ident.is_some_and(|value| value != SERVER_IDENT) {
return Err(LookupError::NotPocketRelay);
}
if details.version < MIN_SERVER_VERSION {
return Err(LookupError::ServerOutdated(
details.version,
MIN_SERVER_VERSION,
));
}
#[cfg(debug_assertions)]
{
use log::debug;
if let Some(association) = &details.association {
debug!("Acquired association token: {}", association);
}
}
Ok(LookupData {
url,
version: details.version,
association: details.association,
})
}
/// Errors from opening an upgraded connection in `create_server_stream` or
/// `create_server_tunnel`.
#[derive(Debug, Error)]
pub enum ServerStreamError {
    /// Sending the upgrade request failed.
    #[error("Request failed: {0}")]
    RequestFailed(reqwest::Error),
    /// The server answered the upgrade request with an error status.
    #[error("Server error response: {0}")]
    ServerError(reqwest::Error),
    /// The response could not be upgraded into a raw connection.
    #[error("Upgrade failed: {0}")]
    UpgradeFailure(reqwest::Error),
}
pub async fn create_server_stream(
http_client: &reqwest::Client,
base_url: &Url,
association: Option<&String>,
) -> Result<Upgraded, ServerStreamError> {
let endpoint_url: Url = base_url
.join(UPGRADE_ENDPOINT)
.expect("Failed to create upgrade endpoint");
let mut headers: HeaderMap<HeaderValue> = [
(header::CONNECTION, HeaderValue::from_static("Upgrade")),
(header::UPGRADE, HeaderValue::from_static("blaze")),
(
HeaderName::from_static(headers::LEGACY_SCHEME),
HeaderValue::from_static("http"),
),
(
HeaderName::from_static(headers::LEGACY_HOST),
HeaderValue::from_static("127.0.0.1"),
),
(
HeaderName::from_static(headers::LEGACY_PORT),
HeaderValue::from(HTTP_PORT),
),
(
HeaderName::from_static(headers::LEGACY_LOCAL_HTTP),
HeaderValue::from_static("true"),
),
]
.into_iter()
.collect();
if let Some(association) = association {
headers.insert(
HeaderName::from_static(headers::ASSOCIATION),
HeaderValue::from_str(association).expect("Invalid association token"),
);
}
let response = http_client
.get(endpoint_url)
.headers(headers)
.send()
.await
.map_err(ServerStreamError::RequestFailed)?;
let response = response
.error_for_status()
.map_err(ServerStreamError::ServerError)?;
response
.upgrade()
.await
.map_err(ServerStreamError::UpgradeFailure)
}
/// A telemetry event posted to the server by `publish_telemetry_event`.
#[derive(Serialize)]
pub struct TelemetryEvent {
    /// Key/value pairs of the event. Serialised by serde as a sequence of
    /// two-element tuples (arrays of `[key, value]` in JSON).
    pub values: Vec<(String, String)>,
}
/// Posts `event` as JSON to the server's [`TELEMETRY_ENDPOINT`].
///
/// The response body is ignored; only the status code is checked.
///
/// # Errors
///
/// Returns the reqwest error if sending fails or the server replies with an
/// error status.
///
/// # Panics
///
/// Panics if [`TELEMETRY_ENDPOINT`] cannot be joined onto `base_url`.
pub async fn publish_telemetry_event(
    http_client: &reqwest::Client,
    base_url: &Url,
    event: TelemetryEvent,
) -> Result<(), reqwest::Error> {
    let target: Url = base_url
        .join(TELEMETRY_ENDPOINT)
        .expect("Failed to create telemetry endpoint");

    http_client
        .post(target)
        .json(&event)
        .send()
        .await?
        .error_for_status()
        .map(|_response| ())
}
#[derive(Debug, Error)]
pub enum ProxyError {
#[error("Request failed: {0}")]
RequestFailed(reqwest::Error),
#[error("Request failed: {0}")]
BodyFailed(reqwest::Error),
}
/// Performs a `GET` to `url` and repackages the reply as a hyper [`Response`].
///
/// The upstream status code and all upstream headers are copied verbatim. The
/// body is read fully into memory before the response is built.
///
/// # Errors
///
/// * [`ProxyError::RequestFailed`] if the request could not be sent
/// * [`ProxyError::BodyFailed`] if the response body could not be read
pub async fn proxy_http_request(
    http_client: &reqwest::Client,
    url: Url,
) -> Result<Response<Body>, ProxyError> {
    let upstream = http_client
        .get(url)
        .send()
        .await
        .map_err(ProxyError::RequestFailed)?;

    // Capture the response head before `bytes()` consumes the upstream response.
    let status = upstream.status();
    let headers = upstream.headers().clone();
    let payload = upstream.bytes().await.map_err(ProxyError::BodyFailed)?;

    let (mut parts, body) = Response::new(Body::from(payload)).into_parts();
    parts.status = status;
    parts.headers = headers;

    Ok(Response::from_parts(parts, body))
}
pub async fn create_server_tunnel(
http_client: &reqwest::Client,
base_url: &Url,
association: &str,
) -> Result<Upgraded, ServerStreamError> {
let endpoint_url: Url = base_url
.join(TUNNEL_ENDPOINT)
.expect("Failed to create tunnel endpoint");
let mut headers: HeaderMap<HeaderValue> = [
(header::CONNECTION, HeaderValue::from_static("Upgrade")),
(header::UPGRADE, HeaderValue::from_static("tunnel")),
]
.into_iter()
.collect();
headers.insert(
HeaderName::from_static(headers::ASSOCIATION),
HeaderValue::from_str(association).expect("Invalid association token"),
);
let response = http_client
.get(endpoint_url)
.headers(headers)
.send()
.await
.map_err(ServerStreamError::RequestFailed)?;
let response = response
.error_for_status()
.map_err(ServerStreamError::ServerError)?;
response
.upgrade()
.await
.map_err(ServerStreamError::UpgradeFailure)
}