mod builder;
pub mod config;
pub use self::builder::BskyAtpAgentBuilder;
use self::config::Config;
use crate::{
error::Result,
moderation::{
util::interpret_label_value_definitions,
{ModerationPrefsLabeler, Moderator},
},
preference::{FeedViewPreferenceData, Preferences, ThreadViewPreferenceData},
};
use atrium_api::{
agent::Configure,
agent::atp_agent::{
AtpAgent,
store::{AtpSessionStore, MemorySessionStore},
},
app::bsky::actor::defs::PreferencesItem,
types::{Object, Union},
xrpc::XrpcClient,
};
#[cfg(feature = "default-client")]
use atrium_xrpc_client::reqwest::ReqwestClient;
use std::{collections::HashMap, ops::Deref, sync::Arc};
/// A Bluesky agent: a thin handle around a shared [`AtpAgent`].
///
/// The wrapped agent is stored behind an [`Arc`], so cloning a `BskyAgent` yields another handle
/// to the *same* underlying agent rather than an independent copy.
///
/// This definition is used when the `default-client` feature is enabled; `T` then defaults to
/// [`ReqwestClient`] and `S` to [`MemorySessionStore`].
#[cfg(feature = "default-client")]
#[derive(Clone)]
pub struct BskyAgent<T = ReqwestClient, S = MemorySessionStore>
where
T: XrpcClient + Send + Sync,
S: AtpSessionStore + Send + Sync,
S::Error: std::error::Error + Send + Sync + 'static,
{
// Shared handle to the wrapped agent. Note that `AtpAgent` takes its parameters as `<S, T>`,
// the reverse of this struct's `<T, S>`.
inner: Arc<AtpAgent<S, T>>,
}
/// A Bluesky agent: a thin handle around a shared [`AtpAgent`].
///
/// This definition is used when the `default-client` feature is disabled, so the XRPC client
/// type `T` has no default and must be named by the caller; `S` still defaults to
/// [`MemorySessionStore`].
///
/// `Clone` is derived here just as it is for the `default-client` definition, so code that clones
/// an agent compiles regardless of which features are enabled. The wrapped agent is stored behind
/// an [`Arc`], so a clone is another handle to the same underlying agent. As with any derived
/// `Clone` on a generic type, the implementation is only available when `T: Clone` and
/// `S: Clone`.
#[cfg(not(feature = "default-client"))]
#[derive(Clone)]
pub struct BskyAgent<T, S = MemorySessionStore>
where
    T: XrpcClient + Send + Sync,
    S: AtpSessionStore + Send + Sync,
    S::Error: std::error::Error + Send + Sync + 'static,
{
    // Shared handle to the wrapped agent. Note that `AtpAgent` takes its parameters as `<S, T>`,
    // the reverse of this struct's `<T, S>`.
    inner: Arc<AtpAgent<S, T>>,
}
// Only available with the `default-client` feature, because it names `ReqwestClient`.
#[cfg_attr(docsrs, doc(cfg(feature = "default-client")))]
#[cfg(feature = "default-client")]
impl BskyAgent {
/// Returns a [`BskyAtpAgentBuilder`] for the default type parameters
/// ([`ReqwestClient`] as the XRPC client and [`MemorySessionStore`] as the session store).
///
/// The builder is obtained from `BskyAtpAgentBuilder::default()`.
pub fn builder() -> BskyAtpAgentBuilder<ReqwestClient, MemorySessionStore> {
BskyAtpAgentBuilder::default()
}
}
impl<T, S> BskyAgent<T, S>
where
T: XrpcClient + Send + Sync,
S: AtpSessionStore + Send + Sync,
S::Error: std::error::Error + Send + Sync + 'static,
{
pub async fn to_config(&self) -> Config {
Config {
endpoint: self.get_endpoint().await,
session: self.get_session().await,
labelers_header: self.get_labelers_header().await,
proxy_header: self.get_proxy_header().await,
}
}
pub async fn get_preferences(&self, enable_bsky_labeler: bool) -> Result<Preferences> {
let mut prefs = Preferences::default();
if enable_bsky_labeler {
prefs.moderation_prefs.labelers.push(ModerationPrefsLabeler::default());
}
let mut label_prefs = Vec::new();
for pref in self
.api
.app
.bsky
.actor
.get_preferences(
atrium_api::app::bsky::actor::get_preferences::ParametersData {}.into(),
)
.await?
.data
.preferences
{
match pref {
Union::Refs(PreferencesItem::AdultContentPref(p)) => {
prefs.moderation_prefs.adult_content_enabled = p.enabled;
}
Union::Refs(PreferencesItem::ContentLabelPref(p)) => {
label_prefs.push(p);
}
Union::Refs(PreferencesItem::SavedFeedsPrefV2(p)) => {
prefs.saved_feeds = p.data.items;
}
Union::Refs(PreferencesItem::FeedViewPref(p)) => {
let mut pref = FeedViewPreferenceData::default();
if let Some(v) = p.hide_replies {
pref.hide_replies = v;
}
if let Some(v) = p.hide_replies_by_unfollowed {
pref.hide_replies_by_unfollowed = v;
}
if let Some(v) = p.hide_replies_by_like_count {
pref.hide_replies_by_like_count = v;
}
if let Some(v) = p.hide_reposts {
pref.hide_reposts = v;
}
if let Some(v) = p.hide_quote_posts {
pref.hide_quote_posts = v;
}
prefs.feed_view_prefs.insert(
p.data.feed,
Object {
data: pref,
extra_data: p.extra_data, },
);
}
Union::Refs(PreferencesItem::ThreadViewPref(p)) => {
let mut pref = ThreadViewPreferenceData::default();
if let Some(v) = &p.sort {
pref.sort = v.clone();
}
if let Some(v) = p.prioritize_followed_users {
pref.prioritize_followed_users = v;
}
prefs.thread_view_prefs = Object {
data: pref,
extra_data: p.extra_data, };
}
Union::Refs(PreferencesItem::MutedWordsPref(p)) => {
prefs.moderation_prefs.muted_words = p.data.items;
}
Union::Refs(PreferencesItem::HiddenPostsPref(p)) => {
prefs.moderation_prefs.hidden_posts = p.data.items;
}
Union::Refs(PreferencesItem::LabelersPref(p)) => {
prefs.moderation_prefs.labelers.extend(p.data.labelers.into_iter().map(
|item| ModerationPrefsLabeler {
did: item.data.did,
labels: HashMap::default(),
is_default_labeler: false,
},
));
}
_ => {
}
}
}
for pref in label_prefs {
if let Some(did) = pref.data.labeler_did {
if let Some(l) = prefs.moderation_prefs.labelers.iter_mut().find(|l| l.did == did) {
l.labels.insert(
pref.data.label,
pref.data.visibility.parse().expect("invalid visibility"),
);
}
} else {
prefs.moderation_prefs.labels.insert(
pref.data.label,
pref.data.visibility.parse().expect("invalid visibility"),
);
}
}
Ok(prefs)
}
pub fn configure_labelers_from_preferences(&self, preferences: &Preferences) {
self.configure_labelers_header(Some(
preferences
.moderation_prefs
.labelers
.iter()
.map(|labeler| (labeler.did.clone(), labeler.is_default_labeler))
.collect(),
));
}
pub async fn moderator(&self, preferences: &Preferences) -> Result<Moderator> {
let views = if preferences.moderation_prefs.labelers.is_empty() {
Vec::new()
} else {
self.api
.app
.bsky
.labeler
.get_services(
atrium_api::app::bsky::labeler::get_services::ParametersData {
detailed: Some(true),
dids: preferences
.moderation_prefs
.labelers
.iter()
.map(|labeler| labeler.did.clone())
.collect(),
}
.into(),
)
.await?
.data
.views
};
let mut label_defs = HashMap::with_capacity(views.len());
for labeler in &views {
let Union::Refs(atrium_api::app::bsky::labeler::get_services::OutputViewsItem::AppBskyLabelerDefsLabelerViewDetailed(labeler_view)) = labeler else {
continue;
};
label_defs.insert(
labeler_view.creator.did.clone(),
interpret_label_value_definitions(labeler_view)?,
);
}
Ok(Moderator::new(
self.get_session().await.map(|s| s.data.did),
preferences.moderation_prefs.clone(),
label_defs,
))
}
}
/// Exposes the wrapped [`AtpAgent`] through dereferencing, so its methods can be called directly
/// on a `BskyAgent`.
impl<T, S> Deref for BskyAgent<T, S>
where
    T: XrpcClient + Send + Sync,
    S: AtpSessionStore + Send + Sync,
    S::Error: std::error::Error + Send + Sync + 'static,
{
    type Target = AtpAgent<S, T>;

    /// Borrows the agent held inside the shared `Arc`.
    fn deref(&self) -> &Self::Target {
        self.inner.as_ref()
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use atrium_api::{
        agent::{AuthorizationProvider, atp_agent::AtpSession},
        com::atproto::server::create_session::OutputData,
        xrpc::types::AuthorizationToken,
    };
    use atrium_common::store::Store;
    use thiserror::Error;

    /// Error type of [`MockSessionStore`]; none of the mock's implemented methods return it.
    #[derive(Error, Debug)]
    #[error("mock session store error")]
    pub struct MockSessionError;

    /// Session store stub for tests: `get` always yields the same fixed session, and the
    /// mutating operations are left unimplemented.
    #[derive(Clone)]
    pub struct MockSessionStore;

    impl Store<(), AtpSession> for MockSessionStore {
        type Error = MockSessionError;

        /// Always returns a fixed session for `did:fake:handle.test` / `handle.test`.
        async fn get(&self, _: &()) -> core::result::Result<Option<AtpSession>, Self::Error> {
            let session: AtpSession = OutputData {
                access_jwt: "access".to_owned(),
                active: None,
                did: "did:fake:handle.test".parse().expect("invalid did"),
                did_doc: None,
                email: None,
                email_auth_factor: None,
                email_confirmed: None,
                handle: "handle.test".parse().expect("invalid handle"),
                refresh_jwt: "refresh".to_owned(),
                status: None,
            }
            .into();
            Ok(Some(session))
        }

        /// Not implemented by this mock; calling it panics.
        async fn set(&self, _: (), _: AtpSession) -> core::result::Result<(), Self::Error> {
            unimplemented!()
        }

        /// Not implemented by this mock; calling it panics.
        async fn del(&self, _: &()) -> core::result::Result<(), Self::Error> {
            unimplemented!()
        }

        /// Not implemented by this mock; calling it panics.
        async fn clear(&self) -> core::result::Result<(), Self::Error> {
            unimplemented!()
        }
    }

    impl AuthorizationProvider for MockSessionStore {
        /// Always supplies the bearer token `access`, whatever the flag argument is.
        async fn authorization_token(&self, _: bool) -> Option<AuthorizationToken> {
            let token = AuthorizationToken::Bearer("access".to_owned());
            Some(token)
        }
    }

    impl AtpSessionStore for MockSessionStore {}

    /// A cloned agent shares state with the original: an endpoint configured on one handle is
    /// observed through the other.
    #[cfg(feature = "default-client")]
    #[tokio::test]
    async fn clone_agent() {
        let original = BskyAgent::builder()
            .store(MockSessionStore)
            .build()
            .await
            .expect("failed to build agent");
        let duplicate = original.clone();

        original.configure_endpoint("https://example.com".to_owned());

        let observed = duplicate.get_endpoint().await;
        assert_eq!(observed, "https://example.com");
    }
}