bsky_sdk/agent/
builder.rs1use super::config::Config;
2use super::BskyAgent;
3use crate::error::Result;
4use atrium_api::agent::{
5 atp_agent::{
6 store::{AtpSessionStore, MemorySessionStore},
7 AtpAgent,
8 },
9 Configure,
10};
11use atrium_api::xrpc::XrpcClient;
12#[cfg(feature = "default-client")]
13use atrium_xrpc_client::reqwest::ReqwestClient;
14use std::sync::Arc;
15
16pub struct BskyAtpAgentBuilder<T, S = MemorySessionStore>
18where
19 T: XrpcClient + Send + Sync,
20 S: AtpSessionStore + Send + Sync,
21 S::Error: std::error::Error + Send + Sync + 'static,
22{
23 config: Config,
24 store: S,
25 client: T,
26}
27
28impl<T> BskyAtpAgentBuilder<T>
29where
30 T: XrpcClient + Send + Sync,
31{
32 pub fn new(client: T) -> Self {
34 Self { config: Config::default(), store: MemorySessionStore::default(), client }
35 }
36}
37
38impl<T, S> BskyAtpAgentBuilder<T, S>
39where
40 T: XrpcClient + Send + Sync,
41 S: AtpSessionStore + Send + Sync,
42 S::Error: std::error::Error + Send + Sync + 'static,
43{
44 pub fn config(mut self, config: Config) -> Self {
46 self.config = config;
47 self
48 }
49 pub fn store<S0>(self, store: S0) -> BskyAtpAgentBuilder<T, S0>
53 where
54 S0: AtpSessionStore + Send + Sync,
55 S0::Error: std::error::Error + Send + Sync + 'static,
56 {
57 BskyAtpAgentBuilder { config: self.config, store, client: self.client }
58 }
59 pub fn client<T0>(self, client: T0) -> BskyAtpAgentBuilder<T0, S>
63 where
64 T0: XrpcClient + Send + Sync,
65 {
66 BskyAtpAgentBuilder { config: self.config, store: self.store, client }
67 }
68 pub async fn build(self) -> Result<BskyAgent<T, S>> {
69 let agent = AtpAgent::new(self.client, self.store);
70 agent.configure_endpoint(self.config.endpoint);
71 if let Some(session) = self.config.session {
72 agent.resume_session(session).await?;
73 }
74 if let Some(labelers) = self.config.labelers_header {
75 agent.configure_labelers_header(Some(
76 labelers
77 .iter()
78 .filter_map(|did| {
79 let (did, redact) = match did.split_once(';') {
80 Some((did, params)) if params.trim() == "redact" => (did, true),
81 None => (did.as_str(), false),
82 _ => return None,
83 };
84 did.parse().ok().map(|did| (did, redact))
85 })
86 .collect(),
87 ));
88 }
89 if let Some(proxy) = self.config.proxy_header {
90 if let Some((did, service_type)) = proxy.split_once('#') {
91 if let Ok(did) = did.parse() {
92 agent.configure_proxy_header(did, service_type);
93 }
94 }
95 }
96 Ok(BskyAgent { inner: Arc::new(agent) })
97 }
98}
99
#[cfg_attr(docsrs, doc(cfg(feature = "default-client")))]
#[cfg(feature = "default-client")]
impl Default for BskyAtpAgentBuilder<ReqwestClient, MemorySessionStore> {
    /// Creates a builder backed by a [`ReqwestClient`] that is constructed
    /// with the endpoint of the default [`Config`].
    ///
    /// The configuration and the session store are the same defaults that
    /// [`BskyAtpAgentBuilder::new`] sets up.
    fn default() -> Self {
        let endpoint = Config::default().endpoint;
        Self::new(ReqwestClient::new(endpoint))
    }
}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::tests::MockSessionStore;

    /// Builds agents from the default builder (default client) and checks the
    /// reported endpoint and session for three setups.
    #[cfg(feature = "default-client")]
    #[tokio::test]
    async fn default() -> Result<()> {
        {
            // Everything left at its default.
            let builder = BskyAtpAgentBuilder::default();
            let agent = builder.build().await?;
            assert_eq!(agent.get_endpoint().await, "https://bsky.social");
            assert_eq!(agent.get_session().await, None);
        }
        {
            // With `MockSessionStore`, the agent is expected to report a
            // session whose handle is `handle.test`.
            let builder = BskyAtpAgentBuilder::default().store(MockSessionStore);
            let agent = builder.build().await?;
            assert_eq!(agent.get_endpoint().await, "https://bsky.social");
            let handle = agent.get_session().await.map(|session| session.data.handle);
            assert_eq!(handle, Some("handle.test".parse().expect("invalid handle")));
        }
        {
            // A custom endpoint supplied through the config.
            let config = Config {
                endpoint: "https://example.com".to_string(),
                ..Default::default()
            };
            let agent = BskyAtpAgentBuilder::default().config(config).build().await?;
            assert_eq!(agent.get_endpoint().await, "https://example.com");
            assert_eq!(agent.get_session().await, None);
        }
        Ok(())
    }

    /// Same scenarios as above, but with an explicitly supplied `MockClient`
    /// when the default client feature is disabled.
    #[cfg(not(feature = "default-client"))]
    #[tokio::test]
    async fn custom() -> Result<()> {
        use crate::tests::MockClient;

        {
            // Custom client only.
            let builder = BskyAtpAgentBuilder::new(MockClient);
            let agent = builder.build().await?;
            assert_eq!(agent.get_endpoint().await, "https://bsky.social");
        }
        {
            // Custom client plus `MockSessionStore`; the agent is expected to
            // report a session whose handle is `handle.test`.
            let builder = BskyAtpAgentBuilder::new(MockClient).store(MockSessionStore);
            let agent = builder.build().await?;
            assert_eq!(agent.get_endpoint().await, "https://bsky.social");
            let handle = agent.get_session().await.map(|session| session.data.handle);
            assert_eq!(handle, Some("handle.test".parse().expect("invalid handle")));
        }
        {
            // Custom client plus a custom endpoint supplied through the config.
            let config = Config {
                endpoint: "https://example.com".to_string(),
                ..Default::default()
            };
            let agent = BskyAtpAgentBuilder::new(MockClient).config(config).build().await?;
            assert_eq!(agent.get_endpoint().await, "https://example.com");
            assert_eq!(agent.get_session().await, None);
        }
        Ok(())
    }
}