use crate::bindings;
use crate::error::{Error, Result};
use std::ffi::{CStr, CString};
use std::os::raw::c_char;
use std::ptr;
use std::sync::Arc;
/// Supplies a password when CUPS prompts for one.
///
/// Called as `(prompt, http, method, resource)`. Return `Some(password)` to answer the
/// prompt or `None` to supply no password. When invoked from libcups through
/// `password_callback_wrapper`, `http` is always `None`.
pub type PasswordCallback =
    dyn Fn(&str, Option<&str>, &str, &str) -> Option<String> + Send + Sync;

/// Looks up a client certificate for the given server name.
///
/// Return `Some(bytes)` to offer a certificate or `None` to offer none.
/// NOTE(review): the expected byte encoding (DER/PEM) is not fixed here — confirm with callers.
pub type ClientCertCallback = dyn Fn(&str) -> Option<Vec<u8>> + Send + Sync;

/// Decides whether to trust a server certificate.
///
/// Called as `(server_name, certificate)`; return `true` to accept it.
pub type ServerCertCallback = dyn Fn(&str, &[u8]) -> bool + Send + Sync;

thread_local! {
    /// Password callback for the current thread; `None` until one is installed.
    static PASSWORD_CALLBACK: std::cell::RefCell<Option<Arc<PasswordCallback>>> =
        const { std::cell::RefCell::new(None) };
}

thread_local! {
    /// Client-certificate callback for the current thread; `None` until one is installed.
    static CLIENT_CERT_CALLBACK: std::cell::RefCell<Option<Arc<ClientCertCallback>>> =
        const { std::cell::RefCell::new(None) };
}

thread_local! {
    /// Server-certificate callback for the current thread; `None` until one is installed.
    static SERVER_CERT_CALLBACK: std::cell::RefCell<Option<Arc<ServerCertCallback>>> =
        const { std::cell::RefCell::new(None) };
}
/// Installs (or, with `None`, clears) the password callback for the current thread.
///
/// The callback is stored in the thread-local `PASSWORD_CALLBACK` slot. With a callback,
/// `password_callback_wrapper` is registered with libcups through `cupsSetPasswordCB2`.
/// Without one, `None` is registered instead.
///
/// # Errors
/// Currently always returns `Ok(())`.
pub fn set_password_callback(callback: Option<Box<PasswordCallback>>) -> Result<()> {
    let install_native = callback.is_some();

    PASSWORD_CALLBACK.with(|slot| *slot.borrow_mut() = callback.map(Arc::from));

    // SAFETY: `password_callback_wrapper` is a plain `extern "C"` function with no captured
    // state. Its user-data argument is ignored and it reads `PASSWORD_CALLBACK` instead, so
    // passing a NULL user-data pointer is sufficient.
    unsafe {
        if !install_native {
            bindings::cupsSetPasswordCB2(None, ptr::null_mut());
        } else {
            bindings::cupsSetPasswordCB2(Some(password_callback_wrapper), ptr::null_mut());
        }
    }

    Ok(())
}
/// Installs (or, with `None`, clears) the client-certificate callback for the current thread.
///
/// The callback is only stored in the thread-local `CLIENT_CERT_CALLBACK` slot. Nothing is
/// registered with libcups here. Within this file the slot is read by
/// `get_client_certificate`.
///
/// # Errors
/// Currently always returns `Ok(())`.
pub fn set_client_cert_callback(callback: Option<Box<ClientCertCallback>>) -> Result<()> {
    CLIENT_CERT_CALLBACK.with(|slot| *slot.borrow_mut() = callback.map(Arc::from));
    Ok(())
}
/// Installs (or, with `None`, clears) the server-certificate callback for the current thread.
///
/// The callback is only stored in the thread-local `SERVER_CERT_CALLBACK` slot. Nothing is
/// registered with libcups here. Within this file the slot is read by
/// `validate_server_certificate`.
///
/// # Errors
/// Currently always returns `Ok(())`.
pub fn set_server_cert_callback(callback: Option<Box<ServerCertCallback>>) -> Result<()> {
    SERVER_CERT_CALLBACK.with(|slot| *slot.borrow_mut() = callback.map(Arc::from));
    Ok(())
}
/// Invokes the password callback installed on the current thread.
///
/// All arguments are forwarded unchanged to the callback. Returns `None` when no callback is
/// installed; otherwise returns whatever the callback returns.
pub fn get_password(
    prompt: &str,
    http: Option<&str>,
    method: &str,
    resource: &str,
) -> Option<String> {
    // Clone the Arc and release the RefCell borrow *before* running user code. The callback
    // can then replace or clear itself via `set_password_callback` without a
    // `BorrowMutError` panic.
    let callback = PASSWORD_CALLBACK.with(|cb| cb.borrow().clone())?;
    callback(prompt, http, method, resource)
}
/// Asks the current thread's client-certificate callback for a certificate for `server_name`.
///
/// Returns `None` when no callback is installed; otherwise returns whatever the callback
/// returns.
pub fn get_client_certificate(server_name: &str) -> Option<Vec<u8>> {
    // Release the RefCell borrow before calling user code, so the callback may call
    // `set_client_cert_callback` without a `BorrowMutError` panic.
    let callback = CLIENT_CERT_CALLBACK.with(|cb| cb.borrow().clone())?;
    callback(server_name)
}
/// Asks the current thread's server-certificate callback whether to trust `certificate`.
///
/// Fails closed: returns `false` (reject) when no callback is installed. Otherwise returns the
/// callback's verdict.
pub fn validate_server_certificate(server_name: &str, certificate: &[u8]) -> bool {
    // Release the RefCell borrow before calling user code, so the callback may call
    // `set_server_cert_callback` without a `BorrowMutError` panic.
    match SERVER_CERT_CALLBACK.with(|cb| cb.borrow().clone()) {
        Some(callback) => callback(server_name, certificate),
        None => false,
    }
}
/// Asks libcups to authenticate a request via `cupsDoAuthentication`.
///
/// `method` and `resource` are forwarded to libcups as C strings. `_http_connection` is
/// currently ignored and a NULL `http_t` pointer is passed instead.
/// NOTE(review): confirm how the linked libcups treats a NULL connection here.
///
/// # Errors
/// - Returns the `?`-converted `CString::new` error if `method` or `resource` contains an
///   interior NUL byte.
/// - Returns [`Error::AuthenticationFailed`] when `cupsDoAuthentication` reports failure.
pub fn do_authentication(
    _http_connection: Option<&str>,
    method: &str,
    resource: &str,
) -> Result<()> {
    let method_c = CString::new(method)?;
    let resource_c = CString::new(resource)?;

    // SAFETY: both string pointers come from `CString`s that live until the end of this
    // function, so they stay valid and NUL-terminated for the whole call.
    let status = unsafe {
        bindings::cupsDoAuthentication(ptr::null_mut(), method_c.as_ptr(), resource_c.as_ptr())
    };

    // Per the CUPS API, `cupsDoAuthentication` returns 0 on success and -1 on error.
    if status == 0 {
        Ok(())
    } else {
        Err(Error::AuthenticationFailed(format!(
            "Authentication failed for {} {}",
            method, resource
        )))
    }
}
/// C-ABI trampoline registered with `cupsSetPasswordCB2` by `set_password_callback`.
///
/// Converts the C arguments to `&str`. A null or non-UTF-8 `prompt`, `method` or `resource`
/// becomes `""`, `"GET"` or `"/"` respectively. It then invokes the Rust callback stored in
/// `PASSWORD_CALLBACK` for the current thread and returns the password as a C string.
///
/// Returns NULL when:
/// - no callback is installed,
/// - the callback returns `None`,
/// - the callback panics, or
/// - the password contains an interior NUL byte.
///
/// The `_http` and `_user_data` arguments are ignored.
///
/// The returned buffer is owned by a thread-local slot inside this function. It stays valid
/// until the next call on the same thread, or until that thread exits, so no allocation is
/// leaked per prompt.
/// NOTE(review): this assumes libcups consumes the password before prompting again on that
/// thread — confirm against the CUPS version in use.
extern "C" fn password_callback_wrapper(
    prompt: *const c_char,
    _http: *mut bindings::_http_s,
    method: *const c_char,
    resource: *const c_char,
    _user_data: *mut std::os::raw::c_void,
) -> *const c_char {
    thread_local! {
        // Owns the most recently returned password so the pointer handed to C stays valid
        // without leaking a fresh allocation on every prompt.
        static LAST_PASSWORD: std::cell::RefCell<Option<CString>> =
            const { std::cell::RefCell::new(None) };
    }

    /// Borrows `ptr` as UTF-8, or returns `default` when it is null or not valid UTF-8.
    ///
    /// # Safety
    /// `ptr` must be null or point to a NUL-terminated string that outlives `'a`.
    unsafe fn str_or<'a>(ptr: *const c_char, default: &'a str) -> &'a str {
        if ptr.is_null() {
            default
        } else {
            // SAFETY: `ptr` is non-null; the caller guarantees NUL termination and lifetime.
            unsafe { CStr::from_ptr(ptr) }.to_str().unwrap_or(default)
        }
    }

    // SAFETY: libcups hands us NULL or NUL-terminated strings that remain valid for the
    // duration of this callback. NOTE(review): confirm against the CUPS headers.
    let (prompt_str, method_str, resource_str) = unsafe {
        (
            str_or(prompt, ""),
            str_or(method, "GET"),
            str_or(resource, "/"),
        )
    };

    // Clone the Arc so the RefCell borrow is released before the user callback runs. The
    // callback may then call `set_password_callback` without a `BorrowMutError` panic.
    let callback = match PASSWORD_CALLBACK.with(|cb| cb.borrow().clone()) {
        Some(callback) => callback,
        None => return ptr::null(),
    };

    // Unwinding out of an `extern "C"` fn is not allowed (it aborts or is UB depending on
    // the toolchain), so a panicking callback is treated as "no password".
    let password = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        callback(prompt_str, None, method_str, resource_str)
    }))
    .ok()
    .flatten();

    match password.map(CString::new) {
        Some(Ok(c_password)) => {
            // Replacing the slot frees the previously returned buffer.
            LAST_PASSWORD.with(|slot| slot.borrow_mut().insert(c_password).as_ptr())
        }
        // `None`, or an interior NUL byte. Refuse rather than send a truncated or empty
        // password.
        _ => ptr::null(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The installed password callback answers a prompt. Once the callback is cleared, the
    /// prompt goes unanswered.
    #[test]
    fn test_set_password_callback() {
        let install = set_password_callback(Some(Box::new(|_prompt, _http, _method, _resource| {
            Some("test_password".to_string())
        })));
        assert!(install.is_ok());
        let answered = get_password("Enter password:", None, "GET", "/");
        assert_eq!(answered, Some("test_password".to_string()));

        let clear = set_password_callback(None);
        assert!(clear.is_ok());
        let unanswered = get_password("Enter password:", None, "GET", "/");
        assert_eq!(unanswered, None);
    }

    /// Client-certificate lookup and server-certificate validation follow the installed
    /// callbacks. Once the callbacks are cleared, they fall back to "no certificate" and
    /// "reject".
    #[test]
    fn test_certificate_callbacks() {
        let expected_cert = vec![1, 2, 3, 4, 5];
        let served_cert = expected_cert.clone();

        let install_client = set_client_cert_callback(Some(Box::new(move |server_name| {
            (server_name == "test.example.com").then(|| served_cert.clone())
        })));
        assert!(install_client.is_ok());
        assert_eq!(get_client_certificate("test.example.com"), Some(expected_cert));
        assert_eq!(get_client_certificate("other.example.com"), None);

        let install_server = set_server_cert_callback(Some(Box::new(|server_name, cert| {
            server_name == "trusted.example.com" && !cert.is_empty()
        })));
        assert!(install_server.is_ok());
        assert!(validate_server_certificate("trusted.example.com", &[1, 2, 3]));
        assert!(!validate_server_certificate("untrusted.example.com", &[1, 2, 3]));
        assert!(!validate_server_certificate("trusted.example.com", &[]));

        assert!(set_client_cert_callback(None).is_ok());
        assert_eq!(get_client_certificate("test.example.com"), None);

        assert!(set_server_cert_callback(None).is_ok());
        assert!(!validate_server_certificate("trusted.example.com", &[1, 2, 3]));
    }
}