use crate::authenticator::{AuthorityMember, AuthorityMembersRepository};
use crate::error::ApiError;
use crate::orchestrator::enroll::auth0::AuthenticateOidcToken;
use core::str;
use ockam::identity::utils::now;
use ockam::identity::Identifier;
use ockam_core::api::{Method, Request, Response};
use ockam_core::{self, Decodable, Result, Routed, SecureChannelLocalInfo, Worker};
use ockam_node::Context;
use reqwest::StatusCode;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::trace;
/// Worker state for enrolling authority members after validating an Okta access token.
pub struct Server {
    /// Identifier of the authority that enrolled members are added under.
    authority: Identifier,
    /// Repository that enrolled members are written to.
    member_attributes_repository: Arc<dyn AuthorityMembersRepository>,
    /// Userinfo claim names copied into an enrolled member's attributes.
    attributes: Vec<String>,
    /// Base URL of the Okta tenant; `/v1/userinfo` is appended when checking tokens.
    tenant_base_url: String,
    /// The only TLS root certificate trusted when contacting the tenant.
    certificate: reqwest::Certificate,
}
#[ockam_core::worker]
impl Worker for Server {
    type Context = Context;
    type Message = Request<Vec<u8>>;

    /// Routes an incoming request to [`Server::on_request`] when it arrived over a
    /// secure channel; otherwise replies with a `forbidden` response.
    ///
    /// The identifier of the secure-channel peer is what gets enrolled, so requests
    /// without secure-channel local info are rejected outright.
    async fn handle_message(
        &mut self,
        ctx: &mut Context,
        msg: Routed<Self::Message>,
    ) -> Result<()> {
        let channel_info = SecureChannelLocalInfo::find_info(msg.local_message());
        let reply_to = msg.return_route().clone();
        let request = msg.into_body()?;

        match channel_info {
            Ok(info) => {
                let response = self
                    .on_request(&info.their_identifier().into(), request)
                    .await?;
                ctx.send(reply_to, response).await
            }
            Err(_) => {
                let denied = Response::forbidden(request.header(), "secure channel required");
                ctx.send(reply_to, denied).await
            }
        }
    }
}
impl Server {
pub fn new(
authority: &Identifier,
member_attributes_repository: Arc<dyn AuthorityMembersRepository>,
tenant_base_url: &str,
certificate: &str,
attributes: &[String],
) -> Result<Self> {
let certificate = reqwest::Certificate::from_pem(certificate.as_bytes())
.map_err(|err| ApiError::core(err.to_string()))?;
Ok(Server {
authority: authority.clone(),
member_attributes_repository,
tenant_base_url: tenant_base_url.to_string(),
certificate,
attributes: attributes.iter().map(|s| s.to_string()).collect(),
})
}
async fn on_request(
&mut self,
from: &Identifier,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>> {
let (header, body) = request.into_parts();
trace! {
target: "ockam_api::okta::server",
from = %from,
id = %header.id(),
method = ?header.method(),
path = %header.path(),
body = %header.has_body(),
"request"
}
let res = match header.method() {
Some(Method::Post) => match header.path_segments::<2>().as_slice() {
["v0", "enroll"] => {
debug!("Checking token");
let token = AuthenticateOidcToken::decode(&body.unwrap_or_default())?;
debug!("device code received: {token:#?}");
if let Some(attrs) = self.check_token(&token.access_token.0).await? {
let attrs = attrs
.into_iter()
.map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
.collect();
let member =
AuthorityMember::new(from.clone(), attrs, from.clone(), now()?, false);
self.member_attributes_repository
.add_member(&self.authority, member)
.await?;
Response::ok().with_headers(&header).encode_body()?
} else {
Response::forbidden(&header, "Forbidden").encode_body()?
}
}
_ => Response::unknown_path(&header).encode_body()?,
},
_ => Response::invalid_method(&header).encode_body()?,
};
Ok(res)
}
async fn check_token(&mut self, token: &str) -> Result<Option<HashMap<String, String>>> {
let client = reqwest::ClientBuilder::new()
.tls_built_in_root_certs(false)
.add_root_certificate(self.certificate.clone())
.build()
.map_err(|err| ApiError::core(err.to_string()))?;
let res = client
.get(format!("{}/v1/userinfo", &self.tenant_base_url))
.header("Authorization", format!("Bearer {token}"))
.send()
.await;
if let Ok(res) = res {
match res.status() {
StatusCode::OK => {
let doc: HashMap<String, serde_json::Value> = res
.json()
.await
.map_err(|_err| ApiError::core("Failed to authenticate with Okta"))?;
debug!("userinfo received: {doc:?}");
let mut custom_attrs = HashMap::new();
for a in self.attributes.iter() {
if let Some(v) = doc.get(a).and_then(|v| v.as_str()) {
custom_attrs.insert(a.to_owned(), v.to_string());
}
}
Ok(Some(custom_attrs))
}
_ => Ok(None),
}
} else {
Ok(None)
}
}
}