use super::LabelerDefs;
use crate::client::{AgentError, AgentSessionExt, CollectionErr, CollectionOutput};
use crate::moderation::labeled::LabeledRecord;
#[cfg(feature = "api_bluesky")]
use jacquard_api::app_bsky::labeler::{
get_services::{GetServices, GetServicesOutput},
service::Service,
};
use jacquard_api::com_atproto::label::{Label, query_labels::QueryLabels};
use jacquard_common::cowstr::ToCowStr;
use jacquard_common::error::ClientError;
use jacquard_common::types::collection::Collection;
use jacquard_common::types::string::Did;
use jacquard_common::types::uri::RecordUri;
use jacquard_common::xrpc::{XrpcClient, XrpcError};
use jacquard_common::{CowStr, IntoStatic};
use std::convert::From;
/// Fetches label value definitions for the given labelers via `app.bsky.labeler.getServices`.
///
/// The request is sent with `detailed = true` so that each returned view can carry the
/// labeler's policies. Only `LabelerViewDetailed` views whose policies declare
/// `label_value_definitions` contribute an entry, keyed by the labeler's creator DID.
/// Any other view variant, and labelers without definitions, are skipped.
///
/// # Errors
///
/// Returns the error from `client.send` unchanged. If the response carries an XRPC
/// error, auth failures are converted with [`ClientError::auth`] and every other
/// variant is reported through [`ClientError::decode`].
#[cfg(feature = "api_bluesky")]
pub async fn fetch_labeler_defs(
    client: &(impl XrpcClient + Sync),
    dids: Vec<Did<'_>>,
) -> Result<LabelerDefs<'static>, ClientError> {
    // The span is attached with `Instrument` instead of being held as an `entered()`
    // guard. A guard kept alive across `.await` leaves the span entered on whichever
    // thread the task was suspended on, mis-attributing other tasks' events. The guard
    // is also `!Send`, so holding it would make this future `!Send`.
    #[cfg(feature = "tracing")]
    let span = tracing::debug_span!("fetch_labeler_defs", count = dids.len());

    let fut = async move {
        use jacquard_api::app_bsky::labeler::get_services::GetServicesOutputViewsItem;

        let request = GetServices::new().dids(dids).detailed(true).build();
        let response = client.send(request).await?;
        let output: GetServicesOutput<'static> = response.into_output().map_err(|e| match e {
            XrpcError::Auth(auth) => ClientError::auth(auth),
            XrpcError::Generic(g) => ClientError::decode(g.to_string()),
            XrpcError::Decode(e) => ClientError::decode(format!("{:?}", e)),
            XrpcError::Xrpc(typed) => ClientError::decode(format!("{:?}", typed)),
            _ => ClientError::decode("unknown XRPC error"),
        })?;

        let mut defs = LabelerDefs::new();
        for view in output.views {
            // Only the detailed view variant is inspected; all other variants are ignored.
            if let GetServicesOutputViewsItem::LabelerViewDetailed(detailed) = view {
                if let Some(label_value_definitions) = &detailed.policies.label_value_definitions
                {
                    defs.insert(
                        detailed.creator.did.clone().into_static(),
                        label_value_definitions
                            .iter()
                            .map(|d| d.clone().into_static())
                            .collect(),
                    );
                }
            }
        }
        // Annotated so the `?` conversions above resolve to `ClientError`.
        Ok::<_, ClientError>(defs)
    };

    #[cfg(feature = "tracing")]
    let fut = tracing::Instrument::instrument(fut, span);

    fut.await
}
/// Fetches label value definitions by reading each labeler's service record directly.
///
/// For every DID, the record at `at://{did}/app.bsky.labeler.service/self` is fetched
/// with `fetch_record`. Fetches run one at a time, in the order given. A labeler whose
/// service policies declare `label_value_definitions` contributes an entry keyed by its
/// DID. Labelers without definitions are skipped.
///
/// # Errors
///
/// Stops at the first failure. A DID that does not produce a valid record URI is
/// reported as [`ClientError::invalid_request`] (wrapped in [`AgentError`]). Any error
/// from `fetch_record` is propagated.
#[cfg(feature = "api_bluesky")]
pub async fn fetch_labeler_defs_direct(
    client: &(impl AgentSessionExt + Sync),
    dids: Vec<Did<'_>>,
) -> Result<LabelerDefs<'static>, AgentError> {
    // Attach the span with `Instrument`. An `entered()` guard held across the `.await`
    // in the loop below would produce incorrect traces and make this future `!Send`.
    #[cfg(feature = "tracing")]
    let span = tracing::debug_span!("fetch_labeler_defs_direct", count = dids.len());

    let fut = async move {
        let mut defs = LabelerDefs::new();
        for did in dids {
            let uri = format!("at://{}/app.bsky.labeler.service/self", did.as_str());
            let record_uri = Service::uri(uri).map_err(|e| {
                AgentError::from(ClientError::invalid_request(format!("Invalid URI: {}", e)))
            })?;
            let output = client.fetch_record(&record_uri).await?;
            let service: Service<'static> = output.value;
            if let Some(label_value_definitions) = service.policies.label_value_definitions {
                defs.insert(did.into_static(), label_value_definitions);
            }
        }
        // Annotated so the `?` conversions above resolve to `AgentError`.
        Ok::<_, AgentError>(defs)
    };

    #[cfg(feature = "tracing")]
    let fut = tracing::Instrument::instrument(fut, span);

    fut.await
}
/// Queries labels via `com.atproto.label.queryLabels` and returns a single page of results.
///
/// * `uri_patterns` is passed through as the request's `uriPatterns`.
/// * `sources` is passed through as the request's `sources` (labeler DIDs).
/// * `cursor` is the continuation cursor from a previous call, or `None` for the first page.
///
/// Each call requests up to 250 labels. The returned tuple holds the labels and the
/// server-provided cursor for the next page, if there is one.
///
/// # Errors
///
/// Errors from `client.send` propagate via `?`. If the response is an XRPC error, auth
/// failures are converted with `AgentError::from` and all other variants are wrapped
/// with `AgentError::xrpc`.
pub async fn fetch_labels(
    client: &impl AgentSessionExt,
    uri_patterns: Vec<CowStr<'_>>,
    sources: Vec<Did<'_>>,
    cursor: Option<CowStr<'_>>,
) -> Result<(Vec<Label<'static>>, Option<CowStr<'static>>), AgentError> {
    // Attach the span with `Instrument`. An `entered()` guard held across `.await`
    // would produce incorrect traces and make this future `!Send`.
    #[cfg(feature = "tracing")]
    let span = tracing::debug_span!("fetch_labels", count = sources.len());

    let fut = async move {
        let request = QueryLabels::new()
            .maybe_cursor(cursor)
            .limit(250)
            .uri_patterns(uri_patterns)
            .sources(sources)
            .build();
        let labels = client
            .send(request)
            .await?
            .into_output()
            .map_err(|e| match e {
                XrpcError::Auth(auth) => AgentError::from(auth),
                // Typed XRPC errors and every other variant are wrapped the same way.
                other => AgentError::xrpc(other),
            })?;
        // Annotated so the `?` conversions above resolve to `AgentError`.
        Ok::<_, AgentError>((labels.labels, labels.cursor))
    };

    #[cfg(feature = "tracing")]
    let fut = tracing::Instrument::instrument(fut, span);

    fut.await
}
/// Fetches a record together with the labels applied to its exact URI.
///
/// The record is fetched first. Labels are then queried from `sources` with the
/// record's URI as the only pattern. Only the first page of labels is returned: the
/// continuation cursor from `fetch_labels` is discarded, so any later pages are not
/// included.
///
/// # Errors
///
/// Propagates any error from fetching the record or from querying its labels.
pub async fn fetch_labeled_record<R>(
    client: &impl AgentSessionExt,
    record_uri: &RecordUri<'_, R>,
    sources: Vec<Did<'_>>,
) -> Result<LabeledRecord<'static, R>, AgentError>
where
    R: Collection + From<CollectionOutput<'static, R>>,
    for<'a> CollectionOutput<'a, R>: IntoStatic<Output = CollectionOutput<'static, R>>,
    for<'a> CollectionErr<'a, R>: IntoStatic<Output = CollectionErr<'static, R>> + Send + Sync,
{
    let fetched = client.fetch_record(record_uri).await?;
    let record: R = fetched.into();

    // The record's own URI is the sole label pattern.
    let patterns = vec![record_uri.as_uri().to_cowstr()];
    // Only the first page is requested; the next-page cursor is not followed.
    let (labels, _next_cursor) = fetch_labels(client, patterns, sources, None).await?;

    Ok(LabeledRecord { labels, record })
}