use async_channel::Receiver;
use axum::{
Router,
extract::{FromRef, Query, State},
response::IntoResponse,
routing::get,
};
use axum_server::tls_rustls::RustlsConfig;
use oauth2::CsrfToken;
use std::{net::SocketAddr, path::Path, result::Result};
use url::Url;
use super::{AuthContext, ChannelMessenger};
use crate::{error::Error, token::auth::AuthRequest};
/// [`ChannelMessenger`] that receives the OAuth authorization code through a
/// local HTTPS server.
///
/// The authorization URL is opened in the user's browser, and the code is
/// captured by a `GET /` handler served on the address taken from the
/// redirect URL. All `Option` fields are `None` after construction and are
/// populated by `with_context`.
#[derive(Debug)]
pub struct LocalServerMessenger {
    /// TLS configuration (certificate + private key) for the local server.
    config: RustlsConfig,
    /// Address the local server binds to, derived from the redirect URL.
    addr: Option<SocketAddr>,
    /// Receiving end of the channel the redirect handler sends the code on.
    rx: Option<Receiver<String>>,
    /// State handed to the redirect handler: expected CSRF token and code sender.
    app_state: Option<AppState>,
    /// Authorization URL opened in the user's browser.
    auth_url: Option<Url>,
}
impl LocalServerMessenger {
    /// Creates a messenger whose local HTTPS server uses `cert.pem` and
    /// `key.pem` from `certs_dir`.
    ///
    /// # Panics
    ///
    /// Panics if the certificate or key cannot be loaded. Use
    /// [`LocalServerMessenger::try_new`] to handle that case instead.
    pub async fn new(certs_dir: &Path) -> Self {
        Self::try_new(certs_dir).await.unwrap_or_else(|err| {
            panic!(
                "failed to load cert.pem/key.pem from {}: {err}",
                certs_dir.display()
            )
        })
    }

    /// Fallible counterpart of [`LocalServerMessenger::new`].
    ///
    /// Loads `cert.pem` and `key.pem` from `certs_dir`. The messenger is not
    /// usable until `with_context` has been called.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by `RustlsConfig::from_pem_file` if the
    /// files cannot be read or parsed.
    pub async fn try_new(certs_dir: &Path) -> std::io::Result<Self> {
        let config =
            RustlsConfig::from_pem_file(certs_dir.join("cert.pem"), certs_dir.join("key.pem"))
                .await?;
        Ok(Self {
            config,
            addr: None,
            rx: None,
            app_state: None,
            auth_url: None,
        })
    }
}
impl ChannelMessenger for LocalServerMessenger {
async fn with_context(&mut self, context: AuthContext) -> Result<(), Error> {
let (tx, rx) = async_channel::unbounded();
let csrf = context
.csrf
.as_ref()
.ok_or(Error::ChannelMessenger("No CSRF".to_string()))?
.clone();
let redirect_uri = context
.redirect_url
.as_ref()
.ok_or(Error::ChannelMessenger("No redirect_url".to_string()))?;
self.app_state = Some(AppState { csrf, tx });
self.rx = Some(rx);
self.auth_url = Some(
context
.auth_url
.as_ref()
.ok_or(Error::ChannelMessenger("No auth_url".to_string()))?
.clone(),
);
self.addr = Some(parse_socket_addr(redirect_uri).map_err(Error::ChannelMessenger)?);
Ok(())
}
async fn send_auth_message(&self) -> Result<(), Error> {
open::that(
self.auth_url
.as_ref()
.ok_or(Error::ChannelMessenger("No auth_url".to_string()))?
.as_ref(),
)?;
Ok(())
}
async fn receive_auth_message(&self) -> Result<String, Error> {
let service = app(self
.app_state
.as_ref()
.ok_or(Error::ChannelMessenger("No app_state".to_string()))?
.clone())
.into_make_service();
tokio::spawn(
axum_server::bind_rustls(
*self
.addr
.as_ref()
.ok_or(Error::ChannelMessenger("No SocketAddr".to_string()))?,
self.config.clone(),
)
.serve(service),
);
let code = self
.rx
.as_ref()
.ok_or(Error::ChannelMessenger("No rx".to_string()))?
.recv()
.await
.map_err(|e| Error::ChannelMessenger(format!("{e:?}")))?;
Ok(code)
}
}
/// Builds the redirect router.
///
/// `GET /` is handled by `get_code`; `app_state` supplies the CSRF token and
/// the code sender that the handler extracts.
fn app(app_state: AppState) -> Router {
    let routes: Router<AppState> = Router::new().route("/", get(get_code));
    routes.with_state(app_state)
}
/// Shared state of the redirect router.
///
/// `FromRef` lets handlers extract each field on its own, as
/// `State<CsrfToken>` and `State<async_channel::Sender<String>>` (see `get_code`).
#[derive(Debug, Clone, FromRef)]
struct AppState {
    /// Token the `state` query parameter of the redirect must match.
    csrf: CsrfToken,
    /// Sender used to hand the authorization code to the messenger.
    tx: async_channel::Sender<String>,
}
async fn get_code(
Query(query): Query<AuthRequest>,
State(csrf): State<CsrfToken>,
State(tx): State<async_channel::Sender<String>>,
) -> impl IntoResponse {
if &query.state != csrf.secret() {
return "CSRF check error".to_string();
}
let content = format!(
"Schwab returned the following code:\n{}\nYou can now safely close this browser window.",
&query.code
);
tx.send(query.code).await.expect("send code");
content
}
/// Derives the local listen address from the OAuth redirect URL.
///
/// The host must be an IP literal: IPv4, or bracketed IPv6 as returned by
/// `Url::host_str`. Host names are not resolved. The port is the URL's
/// explicit port, else the scheme's well-known default (e.g. 443 for `https`,
/// 80 for `http`), else 443.
///
/// # Errors
///
/// Returns a human-readable message if the URL has no host, or if
/// `host:port` does not parse as a [`SocketAddr`] (e.g. the host is a name).
fn parse_socket_addr(url: &Url) -> Result<SocketAddr, String> {
    let hostname = url
        .host_str()
        .ok_or_else(|| "No hostname found in URL".to_string())?;
    // `Url::port()` yields `None` when the port equals the scheme default.
    // On its own it would turn an explicit `http://host:80` into 443, so fall
    // back to the scheme's known default first.
    let port = url.port_or_known_default().unwrap_or(443);
    format!("{hostname}:{port}")
        .parse::<SocketAddr>()
        .map_err(|err| format!("Failed to parse socket address: {err}"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode, Uri},
    };
    use pretty_assertions::assert_eq;
    use std::path::PathBuf;
    use tower::ServiceExt;

    /// Error text `parse_socket_addr` yields when the host is not an IP literal.
    const INVALID_SYNTAX: &str = "Failed to parse socket address: invalid socket address syntax";

    /// Parses `input` into a [`Url`] and runs it through `parse_socket_addr`.
    fn socket_addr_of(input: &str) -> Result<SocketAddr, String> {
        let url: Url = input.parse().expect("test URL should parse");
        parse_socket_addr(&url)
    }

    #[test]
    fn test_parse_socket_addr() {
        let loopback = [127, 0, 0, 1];
        let accepted = [
            ("https://127.0.0.1:8080", SocketAddr::from((loopback, 8080))),
            // No explicit port on an https URL: 443 is used.
            ("https://127.0.0.1", SocketAddr::from((loopback, 443))),
            ("https://127.0.0.1:3000", SocketAddr::from((loopback, 3000))),
            ("https://192.168.1.1:8080", SocketAddr::from(([192, 168, 1, 1], 8080))),
        ];
        for (input, expected) in accepted {
            assert_eq!(socket_addr_of(input).unwrap(), expected, "input: {input}");
        }

        // Host names are rejected rather than resolved.
        for input in ["https:///path", "http://example.com:80"] {
            assert_eq!(socket_addr_of(input).unwrap_err(), INVALID_SYNTAX, "input: {input}");
        }
    }

    /// The router must answer a matching redirect with the code page and
    /// forward the code over the channel.
    #[tokio::test]
    async fn test_router() {
        let csrf = CsrfToken::new_random();
        let uri = Uri::builder()
            .path_and_query(format!("/?code=code&state={}", csrf.secret()))
            .build()
            .unwrap();
        let request = Request::builder().uri(uri).body(Body::empty()).unwrap();
        let (tx, rx) = async_channel::unbounded();

        let response = app(AppState { csrf, tx }).oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(
            std::str::from_utf8(&body).unwrap(),
            "Schwab returned the following code:\ncode\nYou can now safely close this browser window."
        );
        assert_eq!(rx.recv().await.unwrap(), "code");
    }

    /// End-to-end run that opens a real browser; run manually with
    /// `--ignored --nocapture` and accept the test certificate.
    #[tokio::test]
    #[ignore = "Testing manually for browser verification. Should be --nocapture"]
    async fn test_local_server_messenger() {
        let certs_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/certs");
        let context = AuthContext {
            auth_url: Some(
                "https://127.0.0.1:8081/?state=CSRF&code=code"
                    .parse()
                    .unwrap(),
            ),
            csrf: Some(CsrfToken::new("CSRF".to_string())),
            redirect_url: Some("https://127.0.0.1:8081".parse().unwrap()),
        };

        let mut messenger = LocalServerMessenger::new(&certs_dir).await;
        messenger.with_context(context).await.unwrap();
        messenger.send_auth_message().await.unwrap();

        let code = messenger.receive_auth_message().await.unwrap();
        assert_eq!("code", code);
    }
}