use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
};
use crate::{
http::{
restricted::RestrictedResolver, AsyncGenericResolver, AsyncHttpResolver,
SyncGenericResolver, SyncHttpResolver,
},
maybe_send_sync::{MaybeSend, MaybeSync},
settings::Settings,
signer::{BoxedAsyncSigner, BoxedSigner},
AsyncSigner, Error, Result, Signer,
};
/// Phase of a long-running operation, reported to the progress callback
/// (see [`ProgressCallbackFunc`]) together with a step counter and a total.
///
/// The enum is `#[non_exhaustive]`: further phases may be added, so `match`
/// expressions outside this crate need a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProgressPhase {
Reading,
VerifyingManifest,
VerifyingSignature,
VerifyingIngredient,
VerifyingAssetHash,
AddingIngredient,
Thumbnail,
Hashing,
Signing,
Embedding,
FetchingRemoteManifest,
Writing,
FetchingOCSP,
FetchingTimestamp,
}
/// Progress callback signature: `(phase, step, total) -> keep_going`.
///
/// Returning `false` makes `Context::check_progress` fail with
/// `Error::OperationCancelled`. On non-wasm32 targets the callback must also be
/// `Send + Sync`.
#[cfg(not(target_arch = "wasm32"))]
pub type ProgressCallbackFunc = dyn Fn(ProgressPhase, u32, u32) -> bool + Send + Sync;
/// wasm32 variant of the progress callback signature, without the `Send + Sync` bounds.
#[cfg(target_arch = "wasm32")]
pub type ProgressCallbackFunc = dyn Fn(ProgressPhase, u32, u32) -> bool;
/// Where `Context::resolver` gets the synchronous HTTP resolver from.
enum SyncResolverState {
/// Resolver installed with `with_resolver` / `set_resolver`.
Custom(Arc<dyn SyncHttpResolver>),
/// No custom resolver: a default one is built from settings on first use and cached here.
Default(OnceLock<Arc<dyn SyncHttpResolver>>),
}
/// Where `Context::resolver_async` gets the asynchronous HTTP resolver from.
enum AsyncResolverState {
/// Resolver installed with `with_resolver_async` / `set_resolver_async`.
Custom(Arc<dyn AsyncHttpResolver>),
/// No custom resolver: a default one is built from settings on first use and cached here.
Default(OnceLock<Arc<dyn AsyncHttpResolver>>),
}
/// Where `Context::signer` gets the synchronous signer from.
enum SignerState {
/// Signer installed with `with_signer` / `set_signer`.
Custom(BoxedSigner),
/// Built from `settings.signer` (wrapped by `settings.cawg_x509_signer` when
/// present) on first use; the outcome, including an error, is cached here.
FromSettings(OnceLock<Result<BoxedSigner>>),
}
/// Where `Context::async_signer` gets the asynchronous signer from.
enum AsyncSignerState {
/// Signer installed with `with_async_signer` / `set_async_signer`.
Custom(BoxedAsyncSigner),
/// Settings-derived slot; `Context::async_signer` currently only ever stores an
/// error here, because async signers cannot be configured through settings.
FromSettings(OnceLock<Result<BoxedAsyncSigner>>),
}
/// Conversion into [`Settings`], accepted by `Context::with_settings` and
/// `Context::set_settings`.
///
/// Implemented for `Settings`, `&Settings`, `&str` / `String` (JSON or TOML text)
/// and `serde_json::Value`.
pub trait IntoSettings {
/// Converts `self` into a `Settings` value.
///
/// # Errors
///
/// Returns an error when the input cannot be turned into settings.
fn into_settings(self) -> Result<Settings>;
}
impl IntoSettings for Settings {
    /// Identity conversion: an owned `Settings` is handed back unchanged.
    fn into_settings(self) -> Result<Settings> {
        let unchanged = self;
        Ok(unchanged)
    }
}
impl IntoSettings for &Settings {
fn into_settings(self) -> Result<Settings> {
Ok(self.clone())
}
}
impl IntoSettings for &str {
    /// Parses the text as JSON first and, if that fails, as TOML.
    ///
    /// Parsing starts from `Settings::default()`. When both formats fail, the
    /// TOML error is returned and the JSON error is discarded.
    fn into_settings(self) -> Result<Settings> {
        let mut parsed = Settings::default();
        if parsed.update_from_str(self, "json").is_err() {
            parsed.update_from_str(self, "toml")?;
        }
        Ok(parsed)
    }
}
impl IntoSettings for String {
    /// Delegates to the `&str` implementation (JSON first, then TOML).
    fn into_settings(self) -> Result<Settings> {
        let text: &str = &self;
        text.into_settings()
    }
}
impl IntoSettings for serde_json::Value {
fn into_settings(self) -> Result<Settings> {
let json_str = serde_json::to_string(&self).map_err(Error::JsonError)?;
let mut settings = Settings::default();
settings.update_from_str(&json_str, "json")?;
Ok(settings)
}
}
/// Holds settings, HTTP resolvers, signers, an optional progress callback and a
/// cancellation flag.
///
/// Resolvers and signers are either installed explicitly or built lazily from
/// `settings` the first time they are requested.
pub struct Context {
/// Active settings; also the source for settings-derived resolvers and signers.
settings: Settings,
/// Synchronous HTTP resolver: custom, or a lazily-built default.
sync_resolver: SyncResolverState,
/// Asynchronous HTTP resolver: custom, or a lazily-built default.
async_resolver: AsyncResolverState,
/// Synchronous signer: custom, or lazily built from settings.
signer: SignerState,
/// Asynchronous signer: custom, or the settings slot (which only yields an error).
async_signer: AsyncSignerState,
/// Optional callback consulted by `check_progress`.
progress_callback: Option<Box<ProgressCallbackFunc>>,
/// Set by `cancel`; read by `is_cancelled` and `check_progress`.
cancel_flag: AtomicBool,
}
impl Default for Context {
    /// Builds a context with default settings and lazily-initialized resolvers.
    ///
    /// In `cfg(test)` builds the synchronous signer is a Ps256 test signer.
    /// Otherwise it is created from settings on first use. The async signer slot
    /// is always the settings-derived one.
    fn default() -> Self {
        let settings = Settings::default();

        #[cfg(test)]
        let signer = SignerState::Custom(crate::utils::test_signer::test_signer(
            crate::SigningAlg::Ps256,
        ));
        #[cfg(not(test))]
        let signer = SignerState::FromSettings(OnceLock::new());

        Self {
            settings,
            sync_resolver: SyncResolverState::Default(OnceLock::new()),
            async_resolver: AsyncResolverState::Default(OnceLock::new()),
            signer,
            async_signer: AsyncSignerState::FromSettings(OnceLock::new()),
            progress_callback: None,
            cancel_flag: AtomicBool::new(false),
        }
    }
}
impl std::fmt::Debug for Context {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Context")
.field("settings", &self.settings)
.finish()
}
}
impl Context {
    /// Creates a new context; equivalent to [`Context::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Moves the context into an [`Arc`] so it can be shared.
    pub fn into_shared(self) -> Arc<Self> {
        self.into()
    }

    /// Replaces the settings and returns the context (builder style).
    ///
    /// Resolvers and signers previously built from the old settings are
    /// discarded and rebuilt from the new settings on next use. Explicitly
    /// installed resolvers and signers are kept.
    ///
    /// # Errors
    ///
    /// Returns an error if `settings` cannot be converted into [`Settings`];
    /// the context is consumed in that case.
    pub fn with_settings<S: IntoSettings>(mut self, settings: S) -> Result<Self> {
        self.set_settings(settings)?;
        Ok(self)
    }

    /// Replaces the settings in place.
    ///
    /// Resolvers and signers previously built from the old settings are
    /// discarded and rebuilt from the new settings on next use. Explicitly
    /// installed resolvers and signers are kept.
    ///
    /// # Errors
    ///
    /// Returns an error if `settings` cannot be converted into [`Settings`];
    /// the current settings are left untouched in that case.
    pub fn set_settings<S: IntoSettings>(&mut self, settings: S) -> Result<()> {
        self.settings = settings.into_settings()?;
        // Cached defaults were derived from the previous settings (signer
        // configuration, allowed network hosts); drop them so they are rebuilt.
        self.invalidate_settings_caches();
        Ok(())
    }

    /// Returns the current settings.
    pub fn settings(&self) -> &Settings {
        &self.settings
    }

    /// Returns the settings for in-place modification.
    ///
    /// The caller may change any value the cached resolvers and signers were
    /// built from, so the settings-derived ones are discarded up front. They are
    /// rebuilt lazily on next use.
    pub fn settings_mut(&mut self) -> &mut Settings {
        self.invalidate_settings_caches();
        &mut self.settings
    }

    /// Resets every lazily-built value derived from `self.settings`.
    ///
    /// Custom (explicitly installed) resolvers and signers are not affected.
    fn invalidate_settings_caches(&mut self) {
        if let SyncResolverState::Default(cache) = &mut self.sync_resolver {
            *cache = OnceLock::new();
        }
        if let AsyncResolverState::Default(cache) = &mut self.async_resolver {
            *cache = OnceLock::new();
        }
        if let SignerState::FromSettings(cache) = &mut self.signer {
            *cache = OnceLock::new();
        }
        if let AsyncSignerState::FromSettings(cache) = &mut self.async_signer {
            *cache = OnceLock::new();
        }
    }

    /// Uses `resolver` for synchronous HTTP requests (builder style).
    pub fn with_resolver<T: SyncHttpResolver + MaybeSend + MaybeSync + 'static>(
        mut self,
        resolver: T,
    ) -> Self {
        self.sync_resolver = SyncResolverState::Custom(Arc::new(resolver));
        self
    }

    /// Uses `resolver` for synchronous HTTP requests.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn set_resolver<T: SyncHttpResolver + MaybeSend + MaybeSync + 'static>(
        &mut self,
        resolver: T,
    ) -> Result<()> {
        self.sync_resolver = SyncResolverState::Custom(Arc::new(resolver));
        Ok(())
    }

    /// Uses `resolver` for asynchronous HTTP requests (builder style).
    pub fn with_resolver_async<T: AsyncHttpResolver + MaybeSend + MaybeSync + 'static>(
        mut self,
        resolver: T,
    ) -> Self {
        self.async_resolver = AsyncResolverState::Custom(Arc::new(resolver));
        self
    }

    /// Uses `resolver` for asynchronous HTTP requests.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn set_resolver_async<T: AsyncHttpResolver + MaybeSend + MaybeSync + 'static>(
        &mut self,
        resolver: T,
    ) -> Result<()> {
        self.async_resolver = AsyncResolverState::Custom(Arc::new(resolver));
        Ok(())
    }

    /// Returns the synchronous HTTP resolver.
    ///
    /// A custom resolver is returned as-is. Otherwise a default resolver is
    /// built on the first call and cached:
    /// - when `core.allowed_network_hosts` is set, it is a [`RestrictedResolver`]
    ///   wrapping `SyncGenericResolver::new()`, limited to those hosts;
    /// - otherwise it is `SyncGenericResolver::with_redirects()`, falling back to
    ///   the default value if that constructor fails.
    pub fn resolver(&self) -> Arc<dyn SyncHttpResolver> {
        match &self.sync_resolver {
            SyncResolverState::Custom(resolver) => resolver.clone(),
            SyncResolverState::Default(once_lock) => once_lock
                .get_or_init(|| {
                    if self.settings.core.allowed_network_hosts.is_some() {
                        // NOTE(review): the restricted path uses `new()` rather than
                        // `with_redirects()`, presumably so a redirect cannot reach a
                        // host outside the allow-list — confirm before changing.
                        let mut resolver = RestrictedResolver::new(SyncGenericResolver::new());
                        resolver
                            .set_allowed_hosts(self.settings.core.allowed_network_hosts.clone());
                        Arc::new(resolver)
                    } else {
                        Arc::new(SyncGenericResolver::with_redirects().unwrap_or_default())
                    }
                })
                .clone(),
        }
    }

    /// Returns the asynchronous HTTP resolver.
    ///
    /// Mirrors [`Context::resolver`]: custom resolvers are returned as-is.
    /// Otherwise a default is built once and cached, using a
    /// [`RestrictedResolver`] around `AsyncGenericResolver::new()` when
    /// `core.allowed_network_hosts` is set, else
    /// `AsyncGenericResolver::with_redirects()` (or its default value on failure).
    pub fn resolver_async(&self) -> Arc<dyn AsyncHttpResolver> {
        match &self.async_resolver {
            AsyncResolverState::Custom(resolver) => resolver.clone(),
            AsyncResolverState::Default(once_lock) => once_lock
                .get_or_init(|| {
                    if self.settings.core.allowed_network_hosts.is_some() {
                        let mut resolver = RestrictedResolver::new(AsyncGenericResolver::new());
                        resolver
                            .set_allowed_hosts(self.settings.core.allowed_network_hosts.clone());
                        Arc::new(resolver)
                    } else {
                        Arc::new(AsyncGenericResolver::with_redirects().unwrap_or_default())
                    }
                })
                .clone(),
        }
    }

    /// Uses `signer` for synchronous signing (builder style).
    pub fn with_signer<T: Signer + MaybeSend + MaybeSync + 'static>(mut self, signer: T) -> Self {
        self.signer = SignerState::Custom(Box::new(signer));
        self
    }

    /// Uses `signer` for synchronous signing.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn set_signer<T: Signer + MaybeSend + MaybeSync + 'static>(
        &mut self,
        signer: T,
    ) -> Result<()> {
        self.signer = SignerState::Custom(Box::new(signer));
        Ok(())
    }

    /// Uses `signer` for asynchronous signing.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn set_async_signer<T: AsyncSigner + MaybeSend + MaybeSync + 'static>(
        &mut self,
        signer: T,
    ) -> Result<()> {
        self.async_signer = AsyncSignerState::Custom(Box::new(signer));
        Ok(())
    }

    /// Returns the synchronous signer.
    ///
    /// A custom signer is returned directly. Otherwise the signer is built from
    /// `settings.signer` on first use and wrapped by `settings.cawg_x509_signer`
    /// when that is configured. The outcome, success or failure, is cached until
    /// the settings are replaced or mutably borrowed.
    ///
    /// # Errors
    ///
    /// - [`Error::MissingSignerSettings`] when `settings.signer` is `None`.
    /// - [`Error::BadParam`] describing any other failure while building the signer.
    pub fn signer(&self) -> Result<&dyn Signer> {
        match &self.signer {
            SignerState::Custom(signer) => Ok(signer.as_ref()),
            SignerState::FromSettings(once_lock) => {
                let result = once_lock.get_or_init(|| {
                    if let Some(signer_settings) = &self.settings.signer {
                        let c2pa_signer = signer_settings.clone().c2pa_signer()?;
                        if let Some(cawg_settings) = &self.settings.cawg_x509_signer {
                            cawg_settings.clone().cawg_signer(c2pa_signer)
                        } else {
                            Ok(c2pa_signer)
                        }
                    } else {
                        Err(Error::MissingSignerSettings)
                    }
                });
                // The cached error cannot be moved out (and `Error` is not cloned
                // here), so rebuild an equivalent owned error for the caller.
                match result {
                    Ok(boxed) => Ok(boxed.as_ref()),
                    Err(Error::MissingSignerSettings) => Err(Error::MissingSignerSettings),
                    Err(e) => Err(Error::BadParam(format!(
                        "failed to create signer from settings: {e}"
                    ))),
                }
            }
        }
    }

    /// Returns the asynchronous signer.
    ///
    /// Only a custom async signer (installed with `with_async_signer` or
    /// `set_async_signer`) can be returned; async signers cannot be built from
    /// settings.
    ///
    /// # Errors
    ///
    /// [`Error::BadParam`] when no custom async signer has been installed.
    pub fn async_signer(&self) -> Result<&dyn AsyncSigner> {
        match &self.async_signer {
            AsyncSignerState::Custom(signer) => Ok(signer.as_ref()),
            AsyncSignerState::FromSettings(once_lock) => {
                let result = once_lock.get_or_init(|| {
                    Err(Error::BadParam(
                        "Async signer not configured in settings".to_string(),
                    ))
                });
                match result {
                    Ok(boxed) => Ok(boxed.as_ref()),
                    Err(e) => Err(Error::BadParam(format!(
                        "failed to create async signer from settings: {e}"
                    ))),
                }
            }
        }
    }

    /// Uses `signer` for asynchronous signing (builder style).
    ///
    /// Uses the same `MaybeSend + MaybeSync` bounds as [`Context::set_async_signer`],
    /// which covers both the threaded and the wasm32 targets with one definition.
    pub fn with_async_signer<T: AsyncSigner + MaybeSend + MaybeSync + 'static>(
        mut self,
        signer: T,
    ) -> Self {
        self.async_signer = AsyncSignerState::Custom(Box::new(signer));
        self
    }

    /// Installs a progress callback (builder style).
    ///
    /// The callback receives `(phase, step, total)`. Returning `false` cancels
    /// the operation at the next [`Context::check_progress`] call.
    pub fn with_progress_callback<F>(mut self, callback: F) -> Self
    where
        F: Fn(ProgressPhase, u32, u32) -> bool + MaybeSend + MaybeSync + 'static,
    {
        self.progress_callback = Some(Box::new(callback));
        self
    }

    /// Installs or replaces the progress callback.
    ///
    /// See [`Context::with_progress_callback`] for the callback contract.
    pub fn set_progress_callback<F>(&mut self, callback: F)
    where
        F: Fn(ProgressPhase, u32, u32) -> bool + MaybeSend + MaybeSync + 'static,
    {
        self.progress_callback = Some(Box::new(callback));
    }

    /// Requests cancellation.
    ///
    /// The flag is never cleared, so every later [`Context::check_progress`]
    /// call fails with [`Error::OperationCancelled`].
    pub fn cancel(&self) {
        // Release pairs with the Acquire load in `is_cancelled`.
        self.cancel_flag.store(true, Ordering::Release);
    }

    /// Returns `true` once [`Context::cancel`] has been called.
    pub fn is_cancelled(&self) -> bool {
        self.cancel_flag.load(Ordering::Acquire)
    }

    /// Reports progress and checks whether the operation should stop.
    ///
    /// The progress callback, if any, is invoked first with
    /// `(phase, step, total)`. Then the cancellation flag is checked. `total`
    /// is passed through unchanged; the tests use `0` for an indeterminate total.
    ///
    /// # Errors
    ///
    /// [`Error::OperationCancelled`] if the callback returns `false` or
    /// [`Context::cancel`] has been called.
    pub(crate) fn check_progress(&self, phase: ProgressPhase, step: u32, total: u32) -> Result<()> {
        // This can run many times per operation (once per step), so it logs at
        // debug level instead of flooding info-level logs.
        log::debug!("progress: phase={phase:?} step={step}/{total}");
        if let Some(callback) = self.progress_callback.as_deref() {
            if !callback(phase, step, total) {
                return Err(Error::OperationCancelled);
            }
        }
        if self.is_cancelled() {
            return Err(Error::OperationCancelled);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    #[cfg(not(target_arch = "wasm32"))]
    use crate::utils::test_signer::async_test_signer;
    use crate::{
        utils::{test::test_context, test_signer::test_signer},
        SigningAlg,
    };

    /// Returns `true` when both `Arc`s point at the same allocation.
    ///
    /// Only the data address is compared (trait-object metadata is dropped by
    /// the cast), so this works reliably for `Arc<dyn Trait>`.
    fn same_allocation<T: ?Sized>(a: &Arc<T>, b: &Arc<T>) -> bool {
        std::ptr::eq(Arc::as_ptr(a).cast::<()>(), Arc::as_ptr(b).cast::<()>())
    }

    #[test]
    fn test_into_settings_from_settings() {
        let mut settings = Settings::default();
        settings.verify.verify_after_sign = true;
        let context = Context::new().with_settings(settings).unwrap();
        assert!(context.settings().verify.verify_after_sign);
    }

    #[test]
    fn test_into_settings_from_json_str() {
        let json = r#"{"verify": {"verify_after_sign": true}}"#;
        let context = Context::new().with_settings(json).unwrap();
        assert!(context.settings().verify.verify_after_sign);
    }

    #[test]
    fn test_into_settings_from_toml_str() {
        let toml = r#"
[verify]
verify_after_sign = true
"#;
        let context = Context::new().with_settings(toml).unwrap();
        assert!(context.settings().verify.verify_after_sign);
    }

    #[test]
    fn test_into_settings_from_json_value() {
        let value = serde_json::json!({"verify": {"verify_after_sign": true}});
        let context = Context::new().with_settings(value).unwrap();
        assert!(context.settings().verify.verify_after_sign);
    }

    #[test]
    fn test_into_settings_invalid_json() {
        let invalid_json = r#"{"verify": {"verify_after_sign": "#;
        let result = Context::new().with_settings(invalid_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_signer_from_settings() {
        let context = test_context();
        let signer = context.signer();
        assert!(signer.is_ok(), "Signer should be created from settings");
        let signer = signer.unwrap();
        assert!(
            signer.alg() == SigningAlg::Ps256,
            "Signer from settings should have Ps256 algorithm"
        );
        let signer2 = context.signer();
        assert!(signer2.is_ok(), "Cached signer should be returned");
    }

    #[test]
    fn test_signer_missing_settings() {
        // Built by hand so the signer comes from settings instead of the
        // cfg(test) default test signer.
        let mut context = Context {
            settings: Settings::default(),
            sync_resolver: SyncResolverState::Default(OnceLock::new()),
            async_resolver: AsyncResolverState::Default(OnceLock::new()),
            signer: SignerState::FromSettings(OnceLock::new()),
            async_signer: AsyncSignerState::FromSettings(OnceLock::new()),
            progress_callback: None,
            cancel_flag: AtomicBool::new(false),
        };
        context.settings.signer = None;
        context.settings.cawg_x509_signer = None;
        let result = context.signer();
        assert!(
            result.is_err(),
            "Should error when no signer settings present"
        );
        assert!(
            matches!(result, Err(Error::MissingSignerSettings)),
            "Expected MissingSignerSettings error, got: {}",
            match result {
                Ok(_) => "Ok(Signer)".to_string(),
                Err(ref e) => format!("Err({e:?})"),
            }
        );
    }

    #[test]
    fn test_custom_signer() {
        let custom_signer = test_signer(SigningAlg::Es256);
        let context = Context::new().with_signer(custom_signer);
        let signer = context.signer().unwrap();
        assert!(
            signer.alg() == SigningAlg::Es256,
            "Custom signer should have Es256 algorithm"
        );
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_custom_async_signer() {
        let custom_async_signer = async_test_signer(SigningAlg::Es256);
        let context = Context::new().with_async_signer(custom_async_signer);
        let async_signer = context.async_signer().unwrap();
        assert_eq!(
            async_signer.alg(),
            SigningAlg::Es256,
            "Custom async signer should have Es256 algorithm"
        );
        let certs = async_signer.certs().unwrap();
        assert!(!certs.is_empty(), "Async signer should have certificates");
        let signature = async_signer.sign(vec![1, 2, 3, 4]).await;
        assert!(signature.is_ok(), "Sign should succeed");
    }

    #[test]
    fn test_async_signer_missing_settings() {
        let context = Context {
            settings: Settings::default(),
            sync_resolver: SyncResolverState::Default(OnceLock::new()),
            async_resolver: AsyncResolverState::Default(OnceLock::new()),
            signer: SignerState::FromSettings(OnceLock::new()),
            async_signer: AsyncSignerState::FromSettings(OnceLock::new()),
            progress_callback: None,
            cancel_flag: AtomicBool::new(false),
        };
        let result = context.async_signer();
        assert!(
            result.is_err(),
            "Should error when no async signer settings present"
        );
        assert!(
            matches!(result, Err(Error::BadParam(_))),
            "Expected BadParam error"
        );
    }

    #[test]
    fn test_check_progress_no_callback_ok() {
        let context = Context::new();
        let result = context.check_progress(ProgressPhase::Hashing, 1, 1);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_progress_cancelled_returns_error() {
        let context = Context::new();
        context.cancel();
        let result = context.check_progress(ProgressPhase::Signing, 1, 1);
        assert!(matches!(result, Err(Error::OperationCancelled)));
    }

    #[test]
    fn test_check_progress_callback_false_cancels() {
        let context = Context::new().with_progress_callback(|_, _, _| false);
        let result = context.check_progress(ProgressPhase::Reading, 1, 1);
        assert!(matches!(result, Err(Error::OperationCancelled)));
    }

    #[test]
    fn test_check_progress_callback_receives_phase_and_steps() {
        use std::sync::Mutex;
        let received: std::sync::Arc<Mutex<Vec<(ProgressPhase, u32, u32)>>> =
            std::sync::Arc::new(Mutex::new(Vec::new()));
        let received_clone = received.clone();
        let context = Context::new().with_progress_callback(move |phase, step, total| {
            received_clone.lock().unwrap().push((phase, step, total));
            true
        });
        context
            .check_progress(ProgressPhase::Thumbnail, 1, 1)
            .unwrap();
        context
            .check_progress(ProgressPhase::Hashing, 3, 10)
            .unwrap();
        let r = received.lock().unwrap();
        assert_eq!(r.len(), 2);
        assert_eq!(r[0], (ProgressPhase::Thumbnail, 1, 1));
        assert_eq!(r[1], (ProgressPhase::Hashing, 3, 10));
    }

    #[test]
    fn test_check_progress_indeterminate_total_passes_through() {
        use std::sync::Mutex;
        let received: std::sync::Arc<Mutex<Vec<(ProgressPhase, u32, u32)>>> =
            std::sync::Arc::new(Mutex::new(Vec::new()));
        let received_clone = received.clone();
        let context = Context::new().with_progress_callback(move |phase, step, total| {
            received_clone.lock().unwrap().push((phase, step, total));
            true
        });
        context
            .check_progress(ProgressPhase::Hashing, 1, 0)
            .unwrap();
        context
            .check_progress(ProgressPhase::Hashing, 2, 0)
            .unwrap();
        let r = received.lock().unwrap();
        assert_eq!(r.len(), 2);
        assert_eq!(r[0], (ProgressPhase::Hashing, 1, 0));
        assert_eq!(r[1], (ProgressPhase::Hashing, 2, 0));
    }

    #[test]
    fn test_cancel_flag_checked_between_callbacks() {
        let context = Context::new();
        assert!(context.check_progress(ProgressPhase::Hashing, 1, 0).is_ok());
        context.cancel();
        assert!(matches!(
            context.check_progress(ProgressPhase::Hashing, 2, 0),
            Err(Error::OperationCancelled)
        ));
    }

    #[test]
    fn test_is_cancelled_after_cancel() {
        let context = Context::new();
        assert!(!context.is_cancelled());
        context.cancel();
        assert!(context.is_cancelled());
    }

    #[test]
    fn test_default_sync_resolver() {
        let context = Context::new();
        let _resolver = context.resolver();
    }

    #[test]
    fn test_default_async_resolver() {
        let context = Context::new();
        let _resolver = context.resolver_async();
    }

    #[test]
    fn test_custom_sync_resolver() {
        use std::io::Read;

        use http::{Request, Response};

        use crate::http::SyncHttpResolver;

        struct MockSyncResolver;
        impl SyncHttpResolver for MockSyncResolver {
            fn http_resolve(
                &self,
                _request: Request<Vec<u8>>,
            ) -> Result<Response<Box<dyn Read>>, crate::http::HttpResolverError> {
                Ok(
                    Response::builder()
                        .status(200)
                        .body(Box::new(std::io::Cursor::new(b"mock response".to_vec()))
                            as Box<dyn Read>)
                        .unwrap(),
                )
            }
        }

        let context = Context::new().with_resolver(MockSyncResolver);
        let resolver = context.resolver();
        let request = Request::builder()
            .uri("http://example.com")
            .body(vec![])
            .unwrap();
        let response = resolver.http_resolve(request);
        assert!(response.is_ok(), "Mock resolver should succeed");
        let mut body = response.unwrap().into_body();
        let mut buffer = Vec::new();
        body.read_to_end(&mut buffer).unwrap();
        assert_eq!(buffer, b"mock response");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_custom_async_resolver() {
        use std::io::Read;

        use async_trait::async_trait;
        use http::{Request, Response};

        use crate::http::AsyncHttpResolver;

        struct MockAsyncResolver;
        #[async_trait]
        impl AsyncHttpResolver for MockAsyncResolver {
            async fn http_resolve_async(
                &self,
                _request: Request<Vec<u8>>,
            ) -> Result<Response<Box<dyn Read>>, crate::http::HttpResolverError> {
                Ok(Response::builder()
                    .status(200)
                    .body(
                        Box::new(std::io::Cursor::new(b"mock async response".to_vec()))
                            as Box<dyn Read>,
                    )
                    .unwrap())
            }
        }

        let context = Context::new().with_resolver_async(MockAsyncResolver);
        let resolver = context.resolver_async();
        let request = Request::builder()
            .uri("http://example.com")
            .body(vec![])
            .unwrap();
        let response = resolver.http_resolve_async(request).await;
        assert!(response.is_ok(), "Mock async resolver should succeed");
        let mut body = response.unwrap().into_body();
        let mut buffer = Vec::new();
        body.read_to_end(&mut buffer).unwrap();
        assert_eq!(buffer, b"mock async response");
    }

    #[test]
    fn test_resolver_with_allowed_hosts() {
        let settings_toml = r#"
[core]
allowed_network_hosts = ["example.com", "test.org"]
"#;
        let context = Context::new().with_settings(settings_toml).unwrap();
        let _resolver = context.resolver();
        assert_eq!(
            context
                .settings()
                .core
                .allowed_network_hosts
                .as_ref()
                .unwrap()
                .len(),
            2,
            "Should have 2 allowed hosts configured"
        );
    }

    #[test]
    fn test_resolver_caching() {
        // The default resolver is built once and then shared, so repeated calls
        // must hand out the same allocation.
        let context = Context::new();
        let resolver1 = context.resolver();
        let resolver2 = context.resolver();
        let resolver3 = context.resolver();
        assert!(
            same_allocation(&resolver1, &resolver2),
            "Default sync resolver should be cached"
        );
        assert!(
            same_allocation(&resolver2, &resolver3),
            "Default sync resolver should be cached"
        );
    }

    #[test]
    fn test_async_resolver_caching() {
        let context = Context::new();
        let resolver1 = context.resolver_async();
        let resolver2 = context.resolver_async();
        let resolver3 = context.resolver_async();
        assert!(
            same_allocation(&resolver1, &resolver2),
            "Default async resolver should be cached"
        );
        assert!(
            same_allocation(&resolver2, &resolver3),
            "Default async resolver should be cached"
        );
    }

    #[test]
    fn test_set_resolver() {
        use std::io::Read;

        use http::{Request, Response};

        use crate::http::SyncHttpResolver;

        struct CustomResolver;
        impl SyncHttpResolver for CustomResolver {
            fn http_resolve(
                &self,
                _request: Request<Vec<u8>>,
            ) -> Result<Response<Box<dyn Read>>, crate::http::HttpResolverError> {
                let body = Box::new(std::io::Cursor::new(b"custom sync response".to_vec()));
                Ok(Response::builder()
                    .status(200)
                    .body(body as Box<dyn Read>)
                    .unwrap())
            }
        }

        let mut context = Context::new();
        let result = context.set_resolver(CustomResolver);
        assert!(result.is_ok(), "set_resolver should succeed");
        let resolver = context.resolver();
        let request = Request::builder()
            .uri("http://example.com")
            .body(Vec::new())
            .unwrap();
        let response = resolver.http_resolve(request);
        assert!(response.is_ok(), "Custom resolver should be callable");
        let mut body = response.unwrap().into_body();
        let mut buffer = Vec::new();
        body.read_to_end(&mut buffer).unwrap();
        assert_eq!(buffer, b"custom sync response");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_set_resolver_async() {
        use std::io::Read;

        use async_trait::async_trait;
        use http::{Request, Response};

        use crate::http::AsyncHttpResolver;

        struct CustomAsyncResolver;
        #[async_trait]
        impl AsyncHttpResolver for CustomAsyncResolver {
            async fn http_resolve_async(
                &self,
                _request: Request<Vec<u8>>,
            ) -> Result<Response<Box<dyn Read>>, crate::http::HttpResolverError> {
                Ok(Response::builder()
                    .status(200)
                    .body(
                        Box::new(std::io::Cursor::new(b"custom async response".to_vec()))
                            as Box<dyn Read>,
                    )
                    .unwrap())
            }
        }

        let mut context = Context::new();
        let result = context.set_resolver_async(CustomAsyncResolver);
        assert!(result.is_ok(), "set_resolver_async should succeed");
        let resolver = context.resolver_async();
        let request = Request::builder()
            .uri("http://example.com")
            .body(Vec::new())
            .unwrap();
        let response = resolver.http_resolve_async(request).await;
        assert!(response.is_ok(), "Custom async resolver should be callable");
        let mut body = response.unwrap().into_body();
        let mut buffer = Vec::new();
        body.read_to_end(&mut buffer).unwrap();
        assert_eq!(buffer, b"custom async response");
    }

    #[test]
    fn test_set_signer() {
        let custom_signer = test_signer(SigningAlg::Es256);
        let mut context = Context::new();
        let result = context.set_signer(custom_signer);
        assert!(result.is_ok(), "set_signer should succeed");
        let signer = context.signer();
        assert!(signer.is_ok(), "Should be able to retrieve custom signer");
        let signer = signer.unwrap();
        assert_eq!(signer.alg(), SigningAlg::Es256, "Signer should be Es256");
        let signature = signer.sign(b"test data");
        assert!(
            signature.is_ok(),
            "Should be able to sign with custom signer"
        );
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_set_async_signer() {
        let custom_signer = async_test_signer(SigningAlg::Es256);
        let mut context = Context::new();
        let result = context.set_async_signer(custom_signer);
        assert!(result.is_ok(), "set_async_signer should succeed");
        let signer = context.async_signer();
        assert!(
            signer.is_ok(),
            "Should be able to retrieve custom async signer"
        );
        let signer = signer.unwrap();
        assert_eq!(
            signer.alg(),
            SigningAlg::Es256,
            "Async signer should be Es256"
        );
        let signature = signer.sign(b"test data".to_vec()).await;
        assert!(
            signature.is_ok(),
            "Should be able to sign with custom async signer"
        );
    }

    #[test]
    fn test_set_methods_replace_previous_values() {
        let initial_signer = test_signer(SigningAlg::Ps256);
        let mut context = Context::new().with_signer(initial_signer);
        let signer = context.signer().unwrap();
        assert_eq!(
            signer.alg(),
            SigningAlg::Ps256,
            "Initial signer should be Ps256"
        );
        let new_signer = test_signer(SigningAlg::Es256);
        context.set_signer(new_signer).unwrap();
        let signer = context.signer().unwrap();
        assert_eq!(
            signer.alg(),
            SigningAlg::Es256,
            "Signer should now be Es256"
        );
    }
}