use dioxus::prelude::*;
use crate::app::Route;
use crate::state::app_state::{AppState, AuthStatus};
#[component]
pub fn OidcLoginButton(homeserver: String) -> Element {
let mut state = use_context::<Signal<AppState>>();
let navigator = use_navigator();
let mut is_loading = use_signal(|| false);
let mut error_message = use_signal(|| Option::<String>::None);
let mut oidc_available = use_signal(|| Option::<bool>::None);
let mut provider_name = use_signal(|| String::from("OIDC"));
let hs_check = homeserver.clone();
use_effect(move || {
let hs = hs_check.clone();
spawn(async move {
match check_oidc_support(&hs).await {
Ok(Some(name)) => {
provider_name.set(name);
oidc_available.set(Some(true));
}
Ok(None) => {
oidc_available.set(Some(false));
}
Err(e) => {
tracing::warn!("Could not check OIDC support: {e}");
oidc_available.set(Some(false));
}
}
});
});
let on_click = move |_| {
is_loading.set(true);
error_message.set(None);
let hs = homeserver.clone();
spawn(async move {
match start_oidc_login(&hs).await {
Ok(client) => {
tracing::info!("OIDC login successful");
{
let mut w = state.write();
w.auth_status = AuthStatus::LoggedIn;
w.client = Some(client);
}
navigator.push(Route::Home {});
}
Err(e) => {
tracing::error!("OIDC login failed: {e}");
error_message.set(Some(e));
is_loading.set(false);
}
}
});
};
let is_available = *oidc_available.read();
if is_available == Some(false) {
return rsx! {};
}
let provider = provider_name.read().clone();
let btn_label_loading = String::from("Opening browser...");
let btn_label_ready = format!("Sign in with {}", provider);
rsx! {
div {
class: "oidc-login",
if let Some(ref err) = *error_message.read() {
div {
class: "oidc-login__error",
"{err}"
}
}
if is_available == Some(true) {
button {
class: "oidc-login-button",
onclick: on_click,
disabled: *is_loading.read(),
if *is_loading.read() {
"{btn_label_loading}"
} else {
"{btn_label_ready}"
}
}
} else {
div {
class: "oidc-login__checking",
"Checking OIDC support..."
}
}
}
}
}
/// Probes `homeserver` for SSO login and picks a provider name for the login button.
///
/// Fetches the homeserver's advertised login types and walks its `m.login.sso` flows in
/// order. For each flow:
/// - an identity provider whose id or display name mentions "oidc"/"openid"
///   (case-insensitive) is preferred;
/// - otherwise the flow's first listed identity provider is used;
/// - a flow that lists no providers is skipped in favour of later flows.
///
/// If SSO is advertised but no flow names any provider, the generic label `"SSO"` is
/// returned. `start_oidc_login` does not pick a provider, so the flow can presumably still
/// be started in that case.
///
/// # Returns
/// - `Ok(Some(name))`: a provider display name, or `"SSO"`.
/// - `Ok(None)`: the homeserver advertises no SSO flow.
///
/// # Errors
/// Returns a human-readable message if the client cannot be built or the login types
/// request fails.
async fn check_oidc_support(homeserver: &str) -> Result<Option<String>, String> {
    use matrix_sdk::ruma::api::client::session::get_login_types::v3::LoginType;

    /// Label used when SSO is advertised without any identity provider list.
    const GENERIC_SSO_LABEL: &str = "SSO";

    let client = crate::client::build_client(homeserver)
        .await
        .map_err(|e| format!("Failed to connect: {e}"))?;
    let login_types = client
        .matrix_auth()
        .get_login_types()
        .await
        .map_err(|e| format!("Failed to get login types: {e}"))?;

    let mut saw_sso = false;
    for flow in &login_types.flows {
        let LoginType::Sso(sso) = flow else {
            continue;
        };
        saw_sso = true;

        let providers = &sso.identity_providers;
        let chosen = providers
            .iter()
            .find(|p| looks_like_oidc(&p.id) || looks_like_oidc(&p.name))
            .or_else(|| providers.first());
        if let Some(provider) = chosen {
            return Ok(Some(provider.name.clone()));
        }
    }

    // SSO without an advertised provider list is still usable for a provider-less login.
    Ok(saw_sso.then(|| GENERIC_SSO_LABEL.to_owned()))
}

/// Returns `true` if `text` mentions "oidc" or "openid", ignoring case.
fn looks_like_oidc(text: &str) -> bool {
    let lowered = text.to_lowercase();
    lowered.contains("oidc") || lowered.contains("openid")
}
async fn start_oidc_login(homeserver: &str) -> Result<matrix_sdk::Client, String> {
let client = crate::client::build_client(homeserver)
.await
.map_err(|e| format!("Failed to connect: {e}"))?;
let sso_result = client
.matrix_auth()
.login_sso(|url| async move {
tracing::info!("Opening OIDC/SSO URL: {url}");
#[cfg(target_os = "windows")]
{
let _ = std::process::Command::new("cmd")
.args(["/C", "start", "", &url])
.spawn();
}
#[cfg(target_os = "macos")]
{
let _ = std::process::Command::new("open").arg(&url).spawn();
}
#[cfg(target_os = "linux")]
{
let _ = std::process::Command::new("xdg-open").arg(&url).spawn();
}
Ok(())
})
.initial_device_display_name("Netrix")
.await;
match sso_result {
Ok(_) => {
if let Some(session) = client.matrix_auth().session() {
let session_data = crate::persistence::matrix_state::SessionData {
homeserver_url: homeserver.to_string(),
user_id: session.meta.user_id.to_string(),
device_id: session.meta.device_id.to_string(),
access_token: session.tokens.access_token.clone(),
};
if let Err(e) =
crate::persistence::matrix_state::save_session(&session_data).await
{
tracing::error!("Failed to save OIDC session: {e}");
}
}
Ok(client)
}
Err(e) => Err(format!("OIDC login failed: {e}")),
}
}