use crate::error::ApiError;
use core::str;
use minicbor::Decoder;
use ockam::identity::utils::now;
use ockam::identity::TRUST_CONTEXT_ID;
use ockam::identity::{
AttributesEntry, Identifier, IdentityAttributesWriter, IdentitySecureChannelLocalInfo,
};
use ockam_core::api::{Method, RequestHeader, Response};
use ockam_core::compat::sync::Arc;
use ockam_core::{self, Result, Routed, Worker};
use ockam_node::Context;
use reqwest::StatusCode;
use std::collections::HashMap;
use tracing::trace;
/// Worker that enrolls identities whose OIDC access token is accepted by an Okta tenant.
///
/// A token presented over a secure channel is checked against the tenant's
/// userinfo endpoint. The selected userinfo fields are then stored as the
/// attributes of the identity on the other end of the channel.
pub struct Server {
    /// Store into which the attributes of enrolled identities are written.
    attributes_writer: Arc<dyn IdentityAttributesWriter>,
    /// Project name, recorded under the trust context id attribute.
    project: String,
    /// Base URL of the Okta tenant; `/v1/userinfo` is appended to it.
    tenant_base_url: String,
    /// The only root certificate trusted when connecting to the tenant.
    certificate: reqwest::Certificate,
    /// Names of the userinfo fields copied into an identity's attributes.
    attributes: Vec<String>,
}
#[ockam_core::worker]
impl Worker for Server {
    type Context = Context;
    type Message = Vec<u8>;

    /// Handle one incoming request and send the encoded response back to the sender.
    ///
    /// If the message carries secure-channel local info, the request is
    /// processed by `on_request` on behalf of the identity on the other end of
    /// the channel. Otherwise only the request header is decoded and a
    /// `forbidden` response ("secure channel required") is returned.
    async fn handle_message(&mut self, c: &mut Context, m: Routed<Self::Message>) -> Result<()> {
        let reply = match IdentitySecureChannelLocalInfo::find_info(m.local_message()) {
            Ok(info) => self.on_request(&info.their_identity_id(), m.as_body()).await?,
            Err(_) => {
                let header: RequestHeader = Decoder::new(m.as_body()).decode()?;
                Response::forbidden(&header, "secure channel required").to_vec()?
            }
        };
        c.send(m.return_route(), reply).await
    }
}
impl Server {
    /// Create an Okta enrollment server.
    ///
    /// * `attributes_writer` - store into which enrolled identities' attributes are written.
    /// * `project` - project name, stored under the trust context id attribute of every
    ///   enrolled identity.
    /// * `tenant_base_url` - base URL of the Okta tenant; `/v1/userinfo` is appended to it
    ///   when checking tokens. Trailing `/` characters are stripped so the request URL does
    ///   not end up containing `//`.
    /// * `certificate` - PEM-encoded certificate used as the only trusted root when
    ///   connecting to the tenant.
    /// * `attributes` - names of the userinfo fields copied into an enrolled identity's
    ///   attributes.
    ///
    /// # Errors
    ///
    /// Returns an error if `certificate` cannot be parsed as PEM.
    pub fn new(
        attributes_writer: Arc<dyn IdentityAttributesWriter>,
        project: String,
        tenant_base_url: &str,
        certificate: &str,
        attributes: &[String],
    ) -> Result<Self> {
        let certificate = reqwest::Certificate::from_pem(certificate.as_bytes())
            .map_err(|err| ApiError::core(err.to_string()))?;
        Ok(Server {
            attributes_writer,
            project,
            tenant_base_url: tenant_base_url.trim_end_matches('/').to_string(),
            certificate,
            attributes: attributes.to_vec(),
        })
    }

    /// Decode and handle one request sent by the identity `from`.
    ///
    /// Only `POST /v0/enroll` is supported. Its body is decoded as an
    /// `AuthenticateOidcToken`, and the access token is checked against the Okta tenant.
    /// If the token is accepted, the configured userinfo fields plus the project
    /// (under `TRUST_CONTEXT_ID`) are stored as `from`'s attributes and an `ok` response
    /// is produced. Otherwise a `forbidden` response is produced. Other paths and methods
    /// get `unknown_path` / `invalid_method` responses.
    ///
    /// Returns the encoded response.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be decoded, the current time cannot be obtained,
    /// the Okta check fails (see `check_token`), or the attributes cannot be stored.
    async fn on_request(&mut self, from: &Identifier, data: &[u8]) -> Result<Vec<u8>> {
        let mut dec = Decoder::new(data);
        let req: RequestHeader = dec.decode()?;
        trace! {
            target: "ockam_api::okta::server",
            from = %from,
            id = %req.id(),
            method = ?req.method(),
            path = %req.path(),
            body = %req.has_body(),
            "request"
        }
        let res = match req.method() {
            Some(Method::Post) => match req.path_segments::<2>().as_slice() {
                ["v0", "enroll"] => {
                    debug!("Checking token for project {:?}", self.project);
                    let token: crate::cloud::enroll::auth0::AuthenticateOidcToken = dec.decode()?;
                    // Never log the token itself: it is a bearer credential.
                    debug!("access token received, checking it against Okta");
                    if let Some(attrs) = self.check_token(&token.access_token.0).await? {
                        let entry = AttributesEntry::new(
                            attrs
                                .into_iter()
                                .map(|(k, v)| (k.into_bytes(), v.into_bytes()))
                                .chain(std::iter::once((
                                    TRUST_CONTEXT_ID.to_owned(),
                                    self.project.as_bytes().to_vec(),
                                )))
                                .collect(),
                            // Propagate clock errors instead of panicking inside the worker.
                            now()?,
                            None,
                            None,
                        );
                        self.attributes_writer.put_attributes(from, entry).await?;
                        Response::ok(&req).to_vec()?
                    } else {
                        Response::forbidden(&req, "Forbidden").to_vec()?
                    }
                }
                _ => Response::unknown_path(&req).to_vec()?,
            },
            _ => Response::invalid_method(&req).to_vec()?,
        };
        Ok(res)
    }

    /// Validate `token` by calling the Okta tenant's `/v1/userinfo` endpoint.
    ///
    /// On `200 OK`, returns `Ok(Some(fields))`. `fields` maps each configured
    /// attribute name to its value when the userinfo document holds it as a JSON
    /// string; other values are skipped. If the request cannot be sent or any other
    /// status is returned, returns `Ok(None)`, which the caller treats as a rejected
    /// token.
    ///
    /// # Errors
    ///
    /// Fails if the HTTP client cannot be built, or if a `200 OK` body cannot be
    /// parsed as a JSON object.
    async fn check_token(&mut self, token: &str) -> Result<Option<HashMap<String, String>>> {
        // Only the configured certificate is trusted for the tenant connection.
        let client = reqwest::ClientBuilder::new()
            .tls_built_in_root_certs(false)
            .add_root_certificate(self.certificate.clone())
            .build()
            .map_err(|err| ApiError::core(err.to_string()))?;
        let res = match client
            .get(format!("{}/v1/userinfo", self.tenant_base_url))
            .header("Authorization", format!("Bearer {token}"))
            .send()
            .await
        {
            Ok(res) => res,
            Err(err) => {
                // Best effort: an unreachable tenant rejects the token rather than
                // failing the worker, but record why.
                debug!("Okta userinfo request failed: {err}");
                return Ok(None);
            }
        };
        if res.status() != StatusCode::OK {
            debug!("Okta userinfo request rejected with status {}", res.status());
            return Ok(None);
        }
        let doc: HashMap<String, serde_json::Value> = res
            .json()
            .await
            .map_err(|_err| ApiError::core("Failed to authenticate with Okta"))?;
        debug!("userinfo received: {doc:?}");
        let custom_attrs: HashMap<String, String> = self
            .attributes
            .iter()
            .filter_map(|name| {
                doc.get(name)
                    .and_then(|v| v.as_str())
                    .map(|v| (name.clone(), v.to_string()))
            })
            .collect();
        Ok(Some(custom_attrs))
    }
}