use std::{
ffi::CString,
future::Future,
os::raw::{c_uint, c_void},
str::FromStr,
sync::{Arc, Weak},
time::Duration,
};
use async_fn_traits::AsyncFn2;
use async_trait::async_trait;
use extern_c::extern_c;
use ffi_sdk::{FsComponent, TransportConfigMode};
pub use self::config::{DittoConfig, DittoConfigConnect};
use crate::{
ditto::{
init::config::{ActualConfig, InternalConfig},
DittoFields,
},
error,
identity::DittoAuthenticator,
small_peer_info::SmallPeerInfo,
utils::{make_continuation, prelude::*},
warn,
};
pub(crate) mod config;
impl Ditto {
pub async fn open(config: DittoConfig) -> Result<Ditto, DittoError> {
Self::init_sdk_version();
let default_root_dir = default_root_directory();
let customer_config = config;
let mut internal_config = InternalConfig {
legacy_persistence_directory: None,
};
if customer_config.persistence_directory.is_none() {
if let Some(root) = &default_root_dir {
internal_config.legacy_persistence_directory = Some(root.clone());
}
}
let actual_config = ActualConfig {
customer_facing: customer_config,
internal: internal_config,
};
let config_cbor: &[u8] =
&serde_cbor::to_vec(&actual_config).expect("should serialize well-formed DittoConfig");
let (continuation, recv) = make_continuation();
let default_root_dir_ref = default_root_dir
.as_deref()
.and_then(|root| root.to_str())
.ok_or_else(|| {
DittoError::new(
ErrorKind::IO,
"Unable to resolve a default data directory on this platform".to_string(),
)
})?;
let default_root_dir_cstring = CString::from_str(default_root_dir_ref)
.expect("should construct CString from no-nulls &str");
let default_root_dir_cstr = &*default_root_dir_cstring;
let default_root_dir_charp: char_p::Ref<'_> = default_root_dir_cstr.into();
#[allow(clippy::useless_conversion)]
ffi_sdk::dittoffi_ditto_open_async_throws(
config_cbor.into(),
TransportConfigMode::PlatformIndependent,
default_root_dir_charp,
continuation.into(),
);
let ffi_result = recv.await.unwrap();
let ffi_ditto: repr_c::Box<ffi_sdk::Ditto> = ffi_result.into_rust_result()?;
Self::finish_open(ffi_ditto, actual_config.customer_facing)
}
pub fn open_sync(config: DittoConfig) -> Result<Ditto, DittoError> {
Self::init_sdk_version();
let default_root_dir = default_root_directory();
let customer_config = config;
let mut internal_config = InternalConfig {
legacy_persistence_directory: None,
};
if customer_config.persistence_directory.is_none() {
if let Some(root) = &default_root_dir {
internal_config.legacy_persistence_directory = Some(root.clone());
}
}
let actual_config = ActualConfig {
customer_facing: customer_config,
internal: internal_config,
};
let config_cbor: &[u8] =
&serde_cbor::to_vec(&actual_config).expect("should serialize well-formed DittoConfig");
let default_root_dir_ref = default_root_dir
.as_deref()
.and_then(|root| root.to_str())
.ok_or_else(|| {
DittoError::new(
ErrorKind::IO,
"Unable to resolve a default data directory on this platform".to_string(),
)
})?;
let default_root_dir_cstring = CString::from_str(default_root_dir_ref)
.expect("should construct CString from no-nulls &str");
let default_root_dir_cstr = &*default_root_dir_cstring;
let default_root_dir_charp: char_p::Ref<'_> = default_root_dir_cstr.into();
let ffi_ditto: repr_c::Box<ffi_sdk::Ditto> = ffi_sdk::dittoffi_ditto_open_throws(
config_cbor.into(),
TransportConfigMode::PlatformIndependent,
default_root_dir_charp,
)
.into_rust_result()?;
Self::finish_open(ffi_ditto, actual_config.customer_facing)
}
fn finish_open(
ffi_ditto: repr_c::Box<ffi_sdk::Ditto>,
config: DittoConfig,
) -> Result<Ditto, DittoError> {
let ditto: Arc<repr_c::Box<ffi_sdk::Ditto>> = Arc::new(ffi_ditto);
let has_auth = matches!(&config.connect, DittoConfigConnect::Server { .. });
let disk_usage = DiskUsage::new(ditto.retain(), FsComponent::Root);
let small_peer_info = SmallPeerInfo::new(ditto.retain());
let fields = Arc::new_cyclic(|weak_fields: &arc::Weak<_>| {
let store = Store::new(ditto.retain(), weak_fields.clone());
let sync = crate::sync::Sync::new(weak_fields.clone());
let presence = Arc::new(Presence::new(weak_fields.clone()));
DittoFields {
ditto: ditto.retain(),
has_auth,
config,
store,
sync,
presence,
disk_usage,
small_peer_info,
}
});
let ditto = Ditto {
fields,
is_shut_down_able: true,
};
Ok(ditto)
}
}
/// Returns the directory containing the currently running executable.
///
/// Yields `None` when the executable path cannot be determined or has no parent directory.
fn default_root_directory() -> Option<PathBuf> {
    let exe_path = std::env::current_exe().ok()?;
    let exe_dir = exe_path.parent()?;
    Some(exe_dir.to_owned())
}
impl DittoAuthenticator {
    /// Registers `handler` to be called when the current authentication nears expiration.
    ///
    /// The handler is wrapped into an FFI login provider, which is installed on the underlying
    /// Ditto handle. If the owning Ditto instance has already shut down, an error is logged and
    /// nothing is installed.
    pub fn set_expiration_handler<F>(&self, handler: F)
    where
        F: DittoAuthExpirationHandler,
    {
        let fields = match self.ditto_fields.upgrade() {
            Some(fields) => fields,
            None => {
                #[allow(deprecated)]
                {
                    error!("Failed to set expiration handler, Ditto has shut down");
                }
                return;
            }
        };
        let shared_handler = Arc::new(handler);
        let provider = make_login_provider(self.ditto_fields.clone(), shared_handler);
        ffi_sdk::ditto_auth_set_login_provider(&fields.ditto, Some(provider));
    }

    /// Clears the expiration handler by passing `None` as the login provider to the FFI layer.
    ///
    /// If the owning Ditto instance has already shut down, an error is logged instead.
    pub fn clear_expiration_handler(&self) {
        let fields = match self.ditto_fields.upgrade() {
            Some(fields) => fields,
            None => {
                #[allow(deprecated)]
                {
                    error!("Failed to clear expiration handler, Ditto has shut down");
                }
                return;
            }
        };
        ffi_sdk::ditto_auth_set_login_provider(&fields.ditto, None);
    }
}
/// Handler invoked with the time remaining before the current authentication expires.
///
/// Handlers are stored behind an `Arc`. They may be run either on a spawned tokio task or on a
/// temporary runtime, which is why the trait requires `'static + Send + Sync` and a `Send`
/// future.
pub trait DittoAuthExpirationHandler: 'static + Send + Sync {
    /// Called with the owning [`Ditto`] instance and the time remaining until expiration.
    fn on_expiration(
        &self,
        ditto: &Ditto,
        duration_remaining: Duration,
    ) -> impl Send + Future<Output = ()>;
}
/// Any `'static + Send + Sync` async function or closure can be used directly as an expiration
/// handler. It must take `(&Ditto, Duration)` and return a `Send` future with output `()`.
impl<F> DittoAuthExpirationHandler for F
where
    F: 'static + Send + Sync,
    F: for<'r> AsyncFn2<&'r Ditto, Duration, Output = (), OutputFuture: Send>,
{
    async fn on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
        // Call the function itself and await its future.
        self(ditto, duration_remaining).await
    }
}
/// Object-safe counterpart of [`DittoAuthExpirationHandler`].
///
/// `#[async_trait]` boxes the returned future. This lets handlers be type-erased and stored as
/// `Arc<dyn DynDittoAuthExpirationHandler>`.
#[async_trait]
pub(crate) trait DynDittoAuthExpirationHandler: 'static + Send + Sync {
    /// Type-erased equivalent of `DittoAuthExpirationHandler::on_expiration`.
    async fn dyn_on_expiration(&self, ditto: &Ditto, duration_remaining: Duration);
}
// Every statically typed handler can be type-erased by forwarding to its `on_expiration`.
#[async_trait]
impl<F: DittoAuthExpirationHandler> DynDittoAuthExpirationHandler for F {
    async fn dyn_on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
        self.on_expiration(ditto, duration_remaining).await
    }
}
// Lets a type-erased handler be used wherever a `DittoAuthExpirationHandler` is expected, e.g.
// calling `on_expiration` on an `Arc<dyn DynDittoAuthExpirationHandler>`.
impl DittoAuthExpirationHandler for dyn '_ + DynDittoAuthExpirationHandler {
    async fn on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
        self.dyn_on_expiration(ditto, duration_remaining).await
    }
}
/// Wraps `auth_expiration_handler` in an FFI login provider.
///
/// The opaque context handed to the FFI layer points at an `Arc<LoginProviderCtx>`. Its strong
/// count is adjusted through the `ffi_retain` / `ffi_release` callbacks. When the FFI invokes
/// the handler callback with the seconds remaining, the owning Ditto instance is upgraded from
/// `ditto_fields` and the handler is passed to `dispatch_auth_handler`. If Ditto has been shut
/// down, an error is logged and the handler is not run.
pub(crate) fn make_login_provider(
    ditto_fields: Weak<DittoFields>,
    auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
) -> repr_c::Box<ffi_sdk::LoginProvider> {
    // State reachable from the FFI layer through the opaque `ctx` pointer.
    struct LoginProviderCtx {
        // Weak handle: holding the login provider does not by itself keep the fields alive.
        ditto_fields: Weak<DittoFields>,
        auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
    }
    let login_provider_ctx = Arc::new(LoginProviderCtx {
        auth_expiration_handler,
        ditto_fields,
    });
    // NOTE(review): `Arc::as_ptr` does not hand a reference to the FFI side, and the local
    // `login_provider_ctx` is dropped when this function returns. This is only sound if
    // `ditto_auth_client_make_login_provider` calls `ffi_retain` on `ffi_ctx` before returning
    // (i.e. the FFI takes its own reference). If the FFI instead adopts the caller's reference,
    // this should be `Arc::into_raw` — TODO confirm against the FFI contract.
    let ffi_ctx = Arc::as_ptr(&login_provider_ctx) as *mut c_void;
    // SAFETY (retain/release): assumes the FFI only ever passes back `ffi_ctx`, which was derived
    // from an `Arc<LoginProviderCtx>`, and balances every release with a prior retain.
    let ffi_retain = Some(extern_c(|ctx: *mut c_void| unsafe {
        Arc::<LoginProviderCtx>::increment_strong_count(ctx.cast())
    }) as unsafe extern "C" fn(_));
    let ffi_release = Some(extern_c(|ctx: *mut c_void| unsafe {
        Arc::<LoginProviderCtx>::decrement_strong_count(ctx.cast())
    }) as unsafe extern "C" fn(_));
    let ffi_handler = extern_c(|ctx: *mut c_void, secs_remaining: c_uint| {
        // SAFETY: assumes the FFI passes back `ffi_ctx` while holding a reference, so the
        // `LoginProviderCtx` outlives this call.
        let login_provider_ctx: &LoginProviderCtx = unsafe { &*ctx.cast() };
        // Take an owned handle so the handler can outlive this callback (it may be spawned).
        let auth_expiration_handler = login_provider_ctx.auth_expiration_handler.retain();
        let Ok(ditto) = Ditto::upgrade(&login_provider_ctx.ditto_fields) else {
            #[allow(deprecated)] {
                error!("Failed to dispatch auth handler, Ditto has been shut down");
            }
            return;
        };
        // NOTE(review): this runs inside an `extern "C"` callback; a panic raised from here
        // cannot unwind into the FFI caller.
        dispatch_auth_handler(
            auth_expiration_handler,
            ditto,
            Duration::from_secs(secs_remaining.into()),
        );
    });
    // SAFETY: assumes the FFI contract is satisfied by the context pointer and callbacks above
    // (see the notes on `ffi_ctx` and the retain/release callbacks).
    unsafe {
        ffi_sdk::ditto_auth_client_make_login_provider(
            ffi_ctx,
            ffi_retain,
            ffi_release,
            ffi_handler,
        )
    }
}
fn dispatch_auth_handler(
auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
ditto: Ditto,
duration_remaining: Duration,
) {
match tokio::runtime::Handle::try_current() {
Ok(handle) => {
handle.spawn(async move {
auth_expiration_handler
.on_expiration(&ditto, duration_remaining)
.await;
});
}
Err(_) => {
#[allow(deprecated)] {
warn!(
"No tokio runtime available for expiration handler. Creating temporary \
runtime."
);
}
match tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
{
Ok(rt) => {
rt.block_on(async move {
auth_expiration_handler
.on_expiration(&ditto, duration_remaining)
.await;
});
}
Err(e) => {
panic!(
"Failed to create tokio runtime for expiration handler: {}. Consider \
running within a tokio runtime context.",
e
);
}
}
}
}
}
#[cfg(test)]
mod tests {
    //! Checks the tokio runtime assumptions that `dispatch_auth_handler` relies on.

    /// Outside any runtime, `Handle::try_current` must fail. A freshly built current-thread
    /// runtime (the fallback `dispatch_auth_handler` uses) must then be the current runtime
    /// inside `block_on`.
    #[test]
    fn test_runtime_detection_behavior() {
        assert!(
            tokio::runtime::Handle::try_current().is_err(),
            "plain #[test] threads should not have a tokio runtime entered"
        );
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("should be able to build a temporary current-thread runtime");
        let detected_inside =
            rt.block_on(async { tokio::runtime::Handle::try_current().is_ok() });
        assert!(detected_inside, "runtime should be current inside block_on");
    }

    /// Inside `#[tokio::test]` a runtime is available, and tasks spawned onto it run to
    /// completion (the path `dispatch_auth_handler` takes when a runtime is present).
    #[tokio::test]
    async fn test_with_tokio_runtime_available() {
        let handle = tokio::runtime::Handle::try_current()
            .expect("Expected tokio runtime to be available in tokio::test");
        let result = handle
            .spawn(async { 21 * 2 })
            .await
            .expect("spawned task should complete without panicking");
        assert_eq!(result, 42);
    }

    /// The temporary-runtime fallback must actually drive async work to completion.
    #[test]
    fn test_improved_error_handling() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("should be able to build a temporary current-thread runtime");
        let mut ran = false;
        rt.block_on(async {
            ran = true;
        });
        assert!(ran, "async work should have executed in the temporary runtime");
    }
}