use crate::error::{AnyListError, Result};
use crate::login::login;
use crate::utils::generate_id;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde_derive::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
/// Credentials exported from an [`AnyListClient`] so a session can be restored later
/// with [`AnyListClient::from_tokens`].
///
/// `Debug` is implemented by hand rather than derived so that the access and refresh
/// tokens are redacted: printing a `SavedTokens` (in a log line, `dbg!`, or a panic
/// message) must not leak credentials. Serialization is unaffected and still writes the
/// full tokens, since persisting them is the purpose of this type.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SavedTokens {
    /// Bearer token the client sends in the `Authorization` header.
    pub(crate) access_token: String,
    /// Token posted to the refresh endpoint to obtain a new access token.
    pub(crate) refresh_token: String,
    /// Identifier of the logged-in user.
    pub(crate) user_id: String,
    /// Premium flag as reported by the login response.
    pub(crate) is_premium_user: bool,
}

impl std::fmt::Debug for SavedTokens {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SavedTokens")
            .field("access_token", &"<redacted>")
            .field("refresh_token", &"<redacted>")
            .field("user_id", &self.user_id)
            .field("is_premium_user", &self.is_premium_user)
            .finish()
    }
}
impl SavedTokens {
    /// Bundles a set of credentials so they can later be passed to
    /// `AnyListClient::from_tokens`.
    ///
    /// Every string argument accepts anything convertible into a `String`.
    pub fn new(
        access_token: impl Into<String>,
        refresh_token: impl Into<String>,
        user_id: impl Into<String>,
        is_premium_user: bool,
    ) -> Self {
        let access_token: String = access_token.into();
        let refresh_token: String = refresh_token.into();
        let user_id: String = user_id.into();
        Self {
            access_token,
            refresh_token,
            user_id,
            is_premium_user,
        }
    }

    /// Borrowed view of the stored access token.
    pub fn access_token(&self) -> &str {
        self.access_token.as_str()
    }

    /// Borrowed view of the stored refresh token.
    pub fn refresh_token(&self) -> &str {
        self.refresh_token.as_str()
    }

    /// Borrowed view of the stored user identifier.
    pub fn user_id(&self) -> &str {
        self.user_id.as_str()
    }

    /// Stored premium flag.
    pub fn is_premium_user(&self) -> bool {
        self.is_premium_user
    }
}
/// Notification passed to the callback registered through
/// `AnyListClient::on_auth_event`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum AuthEvent {
    /// A token refresh succeeded and the new tokens have been stored.
    TokensRefreshed,
    /// A token refresh failed; the payload describes the failure.
    RefreshFailed(String),
}
/// Mutable authentication state shared behind the client's mutex.
#[derive(Clone)]
struct AuthState {
    /// Current bearer token placed in the `Authorization` header.
    access_token: String,
    /// Token exchanged at the refresh endpoint for a new token pair.
    refresh_token: String,
    /// Identifier of the logged-in user.
    user_id: String,
    /// Premium flag captured at login or restored from saved tokens.
    is_premium_user: bool,
    /// When `true`, a 401 response triggers a token refresh and one retry.
    auto_refresh_enabled: bool,
}
/// Authenticated HTTP client for the AnyList API.
pub struct AnyListClient {
    /// Token state, shared so it can be updated through `&self` when tokens are refreshed.
    auth: Arc<Mutex<AuthState>>,
    /// Optional observer notified of refresh successes and failures.
    auth_event_callback: Option<Arc<dyn Fn(AuthEvent) + Send + Sync>>,
    /// Value sent in the `X-AnyLeaf-Client-Identifier` header on every request.
    client_identifier: String,
    /// Underlying reqwest client used for all requests.
    client: reqwest::Client,
}
impl AnyListClient {
pub async fn login(email: &str, password: &str) -> Result<Self> {
let client_identifier = generate_id();
let login_result = login(email, password, &client_identifier)
.await
.map_err(|e| AnyListError::AuthenticationFailed(e.to_string()))?;
let auth = Arc::new(Mutex::new(AuthState {
access_token: login_result.access_token,
refresh_token: login_result.refresh_token,
user_id: login_result.user_id,
is_premium_user: login_result.is_premium_user,
auto_refresh_enabled: true,
}));
Ok(Self {
auth,
auth_event_callback: None,
client_identifier,
client: reqwest::Client::new(),
})
}
pub fn from_tokens(tokens: SavedTokens) -> Result<Self> {
let auth = Arc::new(Mutex::new(AuthState {
access_token: tokens.access_token,
refresh_token: tokens.refresh_token,
user_id: tokens.user_id,
is_premium_user: tokens.is_premium_user,
auto_refresh_enabled: true,
}));
Ok(Self {
auth,
auth_event_callback: None,
client_identifier: generate_id(),
client: reqwest::Client::new(),
})
}
pub fn export_tokens(&self) -> Result<SavedTokens> {
let auth = self.auth.lock().unwrap();
Ok(SavedTokens {
access_token: auth.access_token.clone(),
refresh_token: auth.refresh_token.clone(),
user_id: auth.user_id.clone(),
is_premium_user: auth.is_premium_user,
})
}
pub fn on_auth_event<F>(mut self, callback: F) -> Self
where
F: Fn(AuthEvent) + Send + Sync + 'static,
{
self.auth_event_callback = Some(Arc::new(callback));
self
}
pub fn disable_auto_refresh(self) -> Self {
let mut auth = self.auth.lock().unwrap();
auth.auto_refresh_enabled = false;
drop(auth);
self
}
pub fn user_id(&self) -> String {
let auth = self.auth.lock().unwrap();
auth.user_id.clone()
}
pub fn is_premium_user(&self) -> bool {
let auth = self.auth.lock().unwrap();
auth.is_premium_user
}
pub fn client_identifier(&self) -> &str {
&self.client_identifier
}
pub fn set_client_identifier(&mut self, id: String) {
self.client_identifier = id;
}
pub async fn start_realtime_sync<F>(
self: &Arc<Self>,
callback: F,
) -> Result<crate::realtime::RealtimeSync>
where
F: Fn(crate::realtime::SyncEvent) + Send + Sync + 'static,
{
let mut sync = crate::realtime::RealtimeSync::new(Arc::clone(self), callback);
sync.connect().await?;
Ok(sync)
}
pub async fn refresh_tokens(&self) -> Result<()> {
let refresh_token = {
let auth = self.auth.lock().unwrap();
auth.refresh_token.clone()
};
let mut headers = HeaderMap::new();
headers.insert("X-AnyLeaf-API-Version", HeaderValue::from_static("3"));
headers.insert(
"X-AnyLeaf-Client-Identifier",
HeaderValue::from_str(&self.client_identifier).unwrap(),
);
let form = reqwest::multipart::Form::new().text("refresh_token", refresh_token);
let response = self
.client
.post("https://www.anylist.com/auth/token/refresh")
.headers(headers)
.multipart(form)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await?;
let error_msg = format!(
"Token refresh failed with status: {}, body: {}",
status, body
);
if let Some(callback) = &self.auth_event_callback {
callback(AuthEvent::RefreshFailed(error_msg.clone()));
}
return Err(AnyListError::AuthenticationFailed(error_msg));
}
#[derive(Deserialize)]
struct RefreshResponse {
access_token: String,
refresh_token: String,
}
let token_response: RefreshResponse = response.json().await?;
{
let mut auth = self.auth.lock().unwrap();
auth.access_token = token_response.access_token;
auth.refresh_token = token_response.refresh_token;
}
if let Some(callback) = &self.auth_event_callback {
callback(AuthEvent::TokensRefreshed);
}
Ok(())
}
fn get_headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
let auth = self.auth.lock().unwrap();
let bearer_value = format!("Bearer {}", auth.access_token);
headers.insert(AUTHORIZATION, HeaderValue::from_str(&bearer_value).unwrap());
drop(auth);
headers.insert("X-AnyLeaf-API-Version", HeaderValue::from_static("3"));
headers.insert(
"X-AnyLeaf-Client-Identifier",
HeaderValue::from_str(&self.client_identifier).unwrap(),
);
headers
}
pub(crate) async fn post(&self, endpoint: &str, body: Vec<u8>) -> Result<Vec<u8>> {
self.post_multipart(&format!("/{}", endpoint), "operations", body).await
}
pub(crate) async fn post_multipart(
&self,
endpoint: &str,
field_name: &str,
body: Vec<u8>,
) -> Result<Vec<u8>> {
let url = format!("https://www.anylist.com{}", endpoint);
let field_name_owned = field_name.to_string();
let form = reqwest::multipart::Form::new()
.part(field_name_owned.clone(), reqwest::multipart::Part::bytes(body.clone()));
let response = self
.client
.post(&url)
.headers(self.get_headers())
.multipart(form)
.send()
.await?;
if response.status() == 401 {
let auto_refresh = {
let auth = self.auth.lock().unwrap();
auth.auto_refresh_enabled
};
if auto_refresh {
self.refresh_tokens().await?;
let retry_form = reqwest::multipart::Form::new()
.part(field_name_owned, reqwest::multipart::Part::bytes(body));
let retry_response = self
.client
.post(&url)
.headers(self.get_headers())
.multipart(retry_form)
.send()
.await?;
if !retry_response.status().is_success() {
return Err(AnyListError::NetworkError(format!(
"Request failed after token refresh with status: {}",
retry_response.status()
)));
}
let bytes = retry_response.bytes().await?;
return Ok(bytes.to_vec());
} else {
return Err(AnyListError::AuthenticationFailed(
"Unauthorized (auto-refresh disabled)".to_string(),
));
}
}
if !response.status().is_success() {
return Err(AnyListError::NetworkError(format!(
"Request failed with status: {}",
response.status()
)));
}
let bytes = response.bytes().await?;
Ok(bytes.to_vec())
}
pub(crate) async fn post_multipart_form(
&self,
endpoint: &str,
form: reqwest::multipart::Form,
) -> Result<Vec<u8>> {
let url = format!("https://www.anylist.com{}", endpoint);
let response = self
.client
.post(&url)
.headers(self.get_headers())
.multipart(form)
.send()
.await?;
if response.status() == 401 {
return Err(AnyListError::AuthenticationFailed(
"Unauthorized - please refresh tokens and retry".to_string(),
));
}
if !response.status().is_success() {
return Err(AnyListError::NetworkError(format!(
"Request failed with status: {}",
response.status()
)));
}
let bytes = response.bytes().await?;
Ok(bytes.to_vec())
}
}