use std::collections::HashMap;
use reifydb_core::interface::auth::AuthStep;
use reifydb_transaction::transaction::Transaction;
use reifydb_type::error::Error;
use tracing::instrument;
use super::{AuthResponse, AuthService, generate_session_token};
use crate::error::AuthError;
impl AuthService {
	/// Authenticates a caller using the named `method` and the supplied `credentials`.
	///
	/// Dispatch order:
	/// 1. If `credentials` contains `challenge_id`, the request is handled as the second leg of a
	///    challenge-response exchange. `method` is not used; the method recorded on the challenge
	///    is used instead.
	/// 2. If `method` is `"token"`, the `token` credential is exchanged for a new session token.
	/// 3. Otherwise the identity named by the `identifier` credential is looked up, and its stored
	///    authentication for `method` is checked by the provider registered for `method`.
	///
	/// When the identity does not exist, it is auto-provisioned only for the `"solana"` method and
	/// only when a `public_key` credential is present.
	///
	/// # Errors
	///
	/// Returns `Err` when any of the following fail:
	/// - beginning the query transaction, or a catalog lookup;
	/// - auto-provisioning;
	/// - the provider check;
	/// - persisting the issued session token.
	///
	/// Also returns `AuthError::UnknownMethod` when a stored authentication exists but no provider
	/// is registered for `method`. Rejected credentials are reported as
	/// `Ok(AuthResponse::Failed { .. })`, not as errors.
	#[instrument(name = "auth::authenticate", level = "debug", skip(self, credentials))]
	pub fn authenticate(&self, method: &str, credentials: HashMap<String, String>) -> Result<AuthResponse, Error> {
		if let Some(challenge_id) = credentials.get("challenge_id").cloned() {
			return self.authenticate_challenge_response(&challenge_id, credentials);
		}
		if method == "token" {
			return self.authenticate_token(credentials);
		}

		let identifier = credentials.get("identifier").map(|s| s.as_str()).unwrap_or("");
		let mut txn = self.engine.begin_query()?;
		let catalog = self.engine.catalog();

		let Some(ident) = catalog.find_identity_by_name(&mut Transaction::Query(&mut txn), identifier)? else {
			// Release the read transaction before handing off to auto-provisioning.
			drop(txn);
			if method == "solana"
				&& let Some(public_key) = credentials.get("public_key").cloned()
			{
				return self.auto_provision_solana(identifier, &public_key, &credentials);
			}
			return Ok(invalid_credentials());
		};

		if !ident.enabled {
			// NOTE(review): this distinct reason is returned before any credential is checked.
			// It therefore tells an unauthenticated caller that the identity exists. The
			// challenge-response path reports the same condition as "invalid credentials".
			// Confirm this disclosure is intended.
			return Ok(AuthResponse::Failed {
				reason: "identity is disabled".to_string(),
			});
		}

		let Some(stored_auth) = catalog.find_authentication_by_identity_and_method(
			&mut Transaction::Query(&mut txn),
			ident.id,
			method,
		)?
		else {
			return Ok(invalid_credentials());
		};
		// All catalog reads are done. Release the read transaction before running the provider
		// check and persisting a session token, as the not-found branch above already does.
		drop(txn);

		let provider = self.auth_registry.get(method).ok_or_else(|| {
			Error::from(AuthError::UnknownMethod {
				method: method.to_string(),
			})
		})?;

		match provider.authenticate(&stored_auth.properties, &credentials)? {
			AuthStep::Authenticated => {
				let token = generate_session_token(&self.rng);
				self.persist_token(&token, ident.id)?;
				Ok(AuthResponse::Authenticated {
					identity: ident.id,
					token,
				})
			}
			AuthStep::Failed => Ok(invalid_credentials()),
			AuthStep::Challenge {
				payload,
			} => {
				// Record the challenge so the follow-up request (which carries `challenge_id`)
				// can be matched to this identifier and method.
				let challenge_id = self.challenges.create(
					identifier.to_string(),
					method.to_string(),
					payload.clone(),
					&self.clock,
					&self.rng,
				);
				Ok(AuthResponse::Challenge {
					challenge_id,
					payload,
				})
			}
		}
	}

	/// Exchanges a non-empty, valid `token` credential for a newly generated session token.
	///
	/// A missing, empty, or unrecognised token yields `AuthResponse::Failed` with reason
	/// `"invalid credentials"`.
	///
	/// # Errors
	///
	/// Returns `Err` only when persisting the new session token fails.
	fn authenticate_token(&self, credentials: HashMap<String, String>) -> Result<AuthResponse, Error> {
		let Some(token_value) = credentials.get("token").filter(|t| !t.is_empty()) else {
			return Ok(invalid_credentials());
		};
		// NOTE(review): no `enabled` check on the token's identity is visible here. Confirm that
		// `validate_token` rejects tokens belonging to disabled identities.
		let Some(token) = self.validate_token(token_value) else {
			return Ok(invalid_credentials());
		};

		let session_token = generate_session_token(&self.rng);
		self.persist_token(&session_token, token.identity)?;
		Ok(AuthResponse::Authenticated {
			identity: token.identity,
			token: session_token,
		})
	}

	/// Completes a challenge-response exchange started by [`Self::authenticate`].
	///
	/// Steps:
	/// 1. Look up and consume the challenge identified by `challenge_id`.
	/// 2. Merge the challenge's stored payload into `credentials`.
	/// 3. Re-run provider verification for the identifier and method recorded on the challenge.
	///
	/// A missing or expired challenge yields `"invalid or expired challenge"`. A provider that asks
	/// for another challenge yields `"nested challenges are not supported"`.
	///
	/// # Errors
	///
	/// Returns `Err` when any of the following fail: beginning the query transaction, a catalog
	/// lookup, the provider check, or persisting the session token. Also returns
	/// `AuthError::UnknownMethod` when no provider is registered for the challenge's method.
	fn authenticate_challenge_response(
		&self,
		challenge_id: &str,
		mut credentials: HashMap<String, String>,
	) -> Result<AuthResponse, Error> {
		let Some(challenge) = self.challenges.consume(challenge_id) else {
			return Ok(AuthResponse::Failed {
				reason: "invalid or expired challenge".to_string(),
			});
		};

		// The challenge payload was generated and stored server-side, so its values are
		// authoritative and must overwrite any same-named keys sent by the client. If client
		// values won, a caller could substitute its own challenge data (e.g. a nonce) and replay
		// a previously valid response.
		for (k, v) in &challenge.payload {
			credentials.insert(k.clone(), v.clone());
		}
		credentials.remove("challenge_id");

		let mut txn = self.engine.begin_query()?;
		let catalog = self.engine.catalog();
		let ident = match catalog.find_identity_by_name(&mut Transaction::Query(&mut txn), &challenge.identifier)? {
			Some(u) if u.enabled => u,
			_ => return Ok(invalid_credentials()),
		};

		let Some(stored_auth) = catalog.find_authentication_by_identity_and_method(
			&mut Transaction::Query(&mut txn),
			ident.id,
			&challenge.method,
		)?
		else {
			return Ok(invalid_credentials());
		};
		// Catalog reads are finished. Release the read transaction before the provider check and
		// the token write.
		drop(txn);

		let provider = self.auth_registry.get(&challenge.method).ok_or_else(|| {
			Error::from(AuthError::UnknownMethod {
				method: challenge.method.clone(),
			})
		})?;

		match provider.authenticate(&stored_auth.properties, &credentials)? {
			AuthStep::Authenticated => {
				let token = generate_session_token(&self.rng);
				self.persist_token(&token, ident.id)?;
				Ok(AuthResponse::Authenticated {
					identity: ident.id,
					token,
				})
			}
			AuthStep::Failed => Ok(invalid_credentials()),
			AuthStep::Challenge {
				..
			} => Ok(AuthResponse::Failed {
				reason: "nested challenges are not supported".to_string(),
			}),
		}
	}
}

/// Builds the generic `"invalid credentials"` rejection.
///
/// It is shared by the unknown-identity, missing-stored-authentication, invalid-token and
/// provider-rejected paths, so those failures look identical to the caller.
fn invalid_credentials() -> AuthResponse {
	AuthResponse::Failed {
		reason: "invalid credentials".to_string(),
	}
}