use std::io::{Read, Write};
use std::net::TcpListener;
use std::time::{Duration, Instant};
use url::Url;
use super::OAUTH_CALLBACK_TIMEOUT;
/// Binds a TCP listener for a loopback redirect URI.
///
/// Returns the listener together with the redirect URI re-rendered with the
/// port reported by the bound socket. The listener is switched to
/// non-blocking mode before it is handed back.
///
/// # Errors
///
/// Returns a human-readable message when `redirect_uri` does not parse, has no
/// host or port, names a host other than `127.0.0.1` / `localhost`, or when
/// binding, configuring, or inspecting the socket fails.
pub(super) fn bind_loopback_listener(redirect_uri: &str) -> Result<(TcpListener, String), String> {
    let mut redirect_url = match Url::parse(redirect_uri) {
        Ok(url) => url,
        Err(error) => return Err(format!("Invalid redirect URI: {error}")),
    };

    let host = match redirect_url.host_str() {
        Some(host) => host,
        None => return Err("Redirect URI must include a host".to_string()),
    };
    if !matches!(host, "127.0.0.1" | "localhost") {
        return Err("Redirect URI must bind to 127.0.0.1 or localhost".to_string());
    }
    let requested_port = match redirect_url.port_or_known_default() {
        Some(port) => port,
        None => return Err("Redirect URI must include a port".to_string()),
    };

    let listener = match TcpListener::bind((host, requested_port)) {
        Ok(listener) => listener,
        Err(error) => {
            return Err(format!(
                "Failed to bind redirect URI {redirect_uri}: {error}"
            ))
        }
    };
    if let Err(error) = listener.set_nonblocking(true) {
        return Err(format!("Failed to configure redirect listener: {error}"));
    }

    // Read the port back from the socket so the returned URI carries the port
    // the listener actually has rather than the one that was requested.
    let bound_port = match listener.local_addr() {
        Ok(address) => address.port(),
        Err(error) => return Err(format!("Failed to inspect redirect listener: {error}")),
    };
    if redirect_url.set_port(Some(bound_port)).is_err() {
        return Err("failed to render redirect listener port".to_string());
    }

    Ok((listener, redirect_url.to_string()))
}
/// Waits for the OAuth redirect on `listener` and returns its `code` parameter.
///
/// The callback is validated by [`wait_for_callback_query`] with
/// `expected_state` as the required `state` value.
///
/// # Errors
///
/// Propagates any error from [`wait_for_callback_query`], and fails when the
/// callback query has no `code` entry.
pub(super) fn wait_for_oauth_code(
    listener: TcpListener,
    redirect_uri: &str,
    expected_state: &str,
) -> Result<String, String> {
    let pairs = wait_for_callback_query(listener, redirect_uri, Some(expected_state))?;
    // The first `code` entry wins if the parameter is repeated.
    for (name, value) in pairs {
        if name == "code" {
            return Ok(value);
        }
    }
    Err("OAuth callback did not include an authorization code".to_string())
}
/// Waits for the GitHub redirect on `listener` and returns its
/// `installation_id` parameter.
///
/// When `expected_state` is `Some`, [`wait_for_callback_query`] requires the
/// callback's `state` to match it; when `None`, no state check is requested.
///
/// # Errors
///
/// Propagates any error from [`wait_for_callback_query`], and fails when the
/// callback query has no `installation_id` entry.
pub(super) fn wait_for_github_installation(
    listener: TcpListener,
    redirect_uri: &str,
    expected_state: Option<&str>,
) -> Result<String, String> {
    let pairs = wait_for_callback_query(listener, redirect_uri, expected_state)?;
    // The first `installation_id` entry wins if the parameter is repeated.
    let installation_id = pairs
        .into_iter()
        .find_map(|(name, value)| (name == "installation_id").then(|| value));
    match installation_id {
        Some(id) => Ok(id),
        None => Err("GitHub callback did not include installation_id".to_string()),
    }
}
pub(super) fn wait_for_callback_query(
listener: TcpListener,
redirect_uri: &str,
expected_state: Option<&str>,
) -> Result<Vec<(String, String)>, String> {
let parsed_redirect =
Url::parse(redirect_uri).map_err(|error| format!("Invalid redirect URI: {error}"))?;
let expected_path = parsed_redirect.path().to_string();
let expected_origin = loopback_origin(&parsed_redirect)?;
let deadline = Instant::now() + OAUTH_CALLBACK_TIMEOUT;
loop {
match listener.accept() {
Ok((mut stream, _)) => {
let mut buffer = [0u8; 8192];
let bytes_read = stream
.read(&mut buffer)
.map_err(|error| format!("Failed to read OAuth callback: {error}"))?;
let request = String::from_utf8_lossy(&buffer[..bytes_read]);
let response;
let result = parse_callback_request(
&request,
&expected_path,
expected_state,
&expected_origin,
);
match result {
Ok(query) => {
response = html_response(
200,
"Authorization complete. You can close this window.",
);
let _ = stream.write_all(response.as_bytes());
return Ok(query);
}
Err(error) => {
response = html_response(400, &error);
let _ = stream.write_all(response.as_bytes());
return Err(error);
}
}
}
Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {
if Instant::now() >= deadline {
return Err("OAuth callback timed out after 5 minutes".to_string());
}
std::thread::sleep(Duration::from_millis(50));
}
Err(error) => return Err(format!("Failed to accept OAuth callback: {error}")),
}
}
}
/// Validates a raw HTTP callback request and extracts its query parameters.
///
/// Checks, in order:
/// 1. the request has a request line with a target (second whitespace token);
/// 2. the first `Origin` header, if any, equals `expected_origin` or `null`;
/// 3. the target, resolved against `expected_origin`, has path `expected_path`;
/// 4. when `expected_state` is `Some`, the first `state` parameter equals it;
/// 5. the query carries no `error` parameter.
///
/// # Errors
///
/// Returns a message naming the first check that failed.
pub(super) fn parse_callback_request(
    request: &str,
    expected_path: &str,
    expected_state: Option<&str>,
    expected_origin: &str,
) -> Result<Vec<(String, String)>, String> {
    let mut remaining_lines = request.lines();
    let request_line = match remaining_lines.next() {
        Some(line) => line,
        None => return Err("OAuth callback request was empty".to_string()),
    };
    let target = match request_line.split_whitespace().nth(1) {
        Some(target) => target,
        None => return Err("OAuth callback request line was invalid".to_string()),
    };

    // Only the first line whose name is `Origin` (case-insensitive) is used.
    let mut origin_header: Option<&str> = None;
    for line in remaining_lines {
        if let Some((name, value)) = line.split_once(':') {
            if name.eq_ignore_ascii_case("origin") {
                origin_header = Some(value.trim());
                break;
            }
        }
    }
    if let Some(origin) = origin_header {
        let origin_allowed = origin == expected_origin || origin == "null";
        if !origin_allowed {
            return Err("OAuth callback Origin header did not match the redirect URI".to_string());
        }
    }

    // The request target is relative, so it is resolved against the expected
    // origin to obtain a full URL that can be parsed.
    let callback_url = match Url::parse(&format!("{expected_origin}{target}")) {
        Ok(url) => url,
        Err(error) => return Err(format!("OAuth callback URL was invalid: {error}")),
    };
    if callback_url.path() != expected_path {
        return Err("Invalid callback path".to_string());
    }

    let mut query: Vec<(String, String)> = Vec::new();
    for (key, value) in callback_url.query_pairs() {
        query.push((key.into_owned(), value.into_owned()));
    }

    if let Some(expected) = expected_state {
        // Only the first `state` parameter is considered.
        match query.iter().find(|(key, _)| key == "state") {
            Some((_, actual)) if actual.as_str() == expected => {}
            _ => return Err("State mismatch".to_string()),
        }
    }

    for (key, value) in &query {
        if key == "error" {
            return Err(format!("Authorization failed: {value}"));
        }
    }

    Ok(query)
}
/// Renders `url` as a `scheme://host:port` origin string.
///
/// The port is always written out, falling back to the scheme's known default
/// when the URL has no explicit port.
///
/// # Errors
///
/// Returns a message when the URL has no host, or has neither an explicit port
/// nor a known default one. A missing host is reported first.
pub(super) fn loopback_origin(url: &Url) -> Result<String, String> {
    match (url.host_str(), url.port_or_known_default()) {
        (None, _) => Err("Redirect URI must include a host".to_string()),
        (Some(_), None) => Err("Redirect URI must include a port".to_string()),
        (Some(host), Some(port)) => Ok(format!("{}://{host}:{port}", url.scheme())),
    }
}
/// Builds a complete HTTP/1.1 response carrying a minimal HTML status page.
///
/// `status` selects the status line: `200` and `400` map to their own lines and
/// every other value is rendered as `404 Not Found`. The page title is
/// "Authorization Complete" for `200` and "Authorization Failed" otherwise.
/// `message` is passed through `html_escape` before being placed in the body.
pub(super) fn html_response(status: u16, message: &str) -> String {
    let (status_line, title) = match status {
        200 => ("HTTP/1.1 200 OK", "Authorization Complete"),
        400 => ("HTTP/1.1 400 Bad Request", "Authorization Failed"),
        _ => ("HTTP/1.1 404 Not Found", "Authorization Failed"),
    };
    let body = format!(
        "<!doctype html><html><head><meta charset=\"utf-8\"><title>{title}</title></head><body><h1>{title}</h1><p>{}</p></body></html>",
        html_escape(message)
    );
    format!(
        "{status_line}\r\nContent-Type: text/html; charset=utf-8\r\nConnection: close\r\n\r\n{body}"
    )
}
/// Escapes the characters that are significant in HTML text and attribute
/// values so `text` can be embedded in a page as literal content.
///
/// `&`, `<`, `>`, `"` and `'` are replaced with `&amp;`, `&lt;`, `&gt;`,
/// `&quot;` and `&#39;`; every other character is copied through unchanged.
fn html_escape(text: &str) -> String {
    // Escaping can only grow the string, so the input length is a lower bound.
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            _ => escaped.push(ch),
        }
    }
    escaped
}