floe_core/io/storage/providers/
adls.rs1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use azure_core::prelude::IfMatchCondition;
5use azure_identity::{DefaultAzureCredential, TokenCredentialOptions};
6use azure_storage::StorageCredentials;
7use azure_storage_blobs::prelude::{BlobServiceClient, ContainerClient};
8use futures::StreamExt;
9use tokio::runtime::Runtime;
10
11use crate::errors::StorageError;
12use crate::io::storage::{
13 planner, uri, validation, ConditionalWrite, ObjectRef, StorageClient, StoredObject,
14};
15use crate::{config, FloeResult};
16
17pub struct AdlsClient {
18 account: String,
19 container: String,
20 prefix: String,
21 runtime: Runtime,
22 container_client: ContainerClient,
23}
24
25impl AdlsClient {
26 pub fn new(definition: &config::StorageDefinition) -> FloeResult<Self> {
27 let account =
28 validation::require_field(definition, definition.account.as_ref(), "account", "adls")?;
29 let container = validation::require_field(
30 definition,
31 definition.container.as_ref(),
32 "container",
33 "adls",
34 )?;
35 let prefix = definition.prefix.clone().unwrap_or_default();
36 let runtime = tokio::runtime::Builder::new_current_thread()
37 .enable_all()
38 .build()
39 .map_err(|err| Box::new(StorageError(format!("adls runtime init failed: {err}"))))?;
40 let credential = DefaultAzureCredential::create(TokenCredentialOptions::default())
41 .map_err(|err| Box::new(StorageError(format!("adls credential init failed: {err}"))))?;
42 let storage_credentials = StorageCredentials::token_credential(Arc::new(credential));
43 let service_client = BlobServiceClient::new(account.clone(), storage_credentials);
44 let container_client = service_client.container_client(container.clone());
45 Ok(Self {
46 account,
47 container,
48 prefix,
49 runtime,
50 container_client,
51 })
52 }
53
54 fn base_prefix(&self) -> String {
55 planner::normalize_separators(&self.prefix)
56 }
57
58 fn full_path(&self, path: &str) -> String {
59 let prefix = self.base_prefix();
60 let joined = planner::join_prefix(&prefix, &planner::normalize_separators(path));
61 joined.trim_start_matches('/').to_string()
62 }
63
64 fn format_abfs(&self, path: &str) -> String {
65 format_abfs_uri(&self.container, &self.account, path)
66 }
67}
68
69impl StorageClient for AdlsClient {
70 fn list(&self, prefix_or_path: &str) -> FloeResult<Vec<ObjectRef>> {
71 let prefix = self.full_path(prefix_or_path);
72 let container = self.container.clone();
73 let account = self.account.clone();
74 let client = self.container_client.clone();
75 self.runtime.block_on(async move {
76 let mut refs = Vec::new();
77 let mut stream = client.list_blobs().prefix(prefix.clone()).into_stream();
78 while let Some(resp) = stream.next().await {
79 let resp = resp.map_err(|err| {
80 Box::new(StorageError(format!("adls list failed: {err}")))
81 as Box<dyn std::error::Error + Send + Sync>
82 })?;
83 for blob in resp.blobs.blobs() {
84 let key = blob.name.clone();
85 let uri = if key.is_empty() {
86 format!("abfs://{}@{}.dfs.core.windows.net", container, account)
87 } else {
88 format!(
89 "abfs://{}@{}.dfs.core.windows.net/{}",
90 container, account, key
91 )
92 };
93 refs.push(planner::object_ref(
94 uri,
95 key,
96 Some(blob.properties.last_modified.to_string()),
97 Some(blob.properties.content_length),
98 ));
99 }
100 }
101 Ok(planner::stable_sort_refs(refs))
102 })
103 }
104
105 fn download_to_temp(&self, uri: &str, temp_dir: &Path) -> FloeResult<PathBuf> {
106 let key = uri
107 .split_once(".dfs.core.windows.net/")
108 .map(|(_, tail)| tail)
109 .unwrap_or("")
110 .trim_start_matches('/')
111 .to_string();
112 let key = if key.is_empty() {
113 return Err(Box::new(StorageError(
114 "adls download requires a blob path".to_string(),
115 )));
116 } else {
117 key
118 };
119 let dest = planner::temp_path_for_key(temp_dir, &key);
120 let dest_clone = dest.clone();
121 let client = self.container_client.clone();
122 let key_clone = key.clone();
123 self.runtime.block_on(async move {
124 if let Some(parent) = dest_clone.parent() {
125 tokio::fs::create_dir_all(parent).await?;
126 }
127 let blob = client.blob_client(key_clone);
128 let mut stream = blob.get().into_stream();
129 let mut file = tokio::fs::File::create(&dest_clone).await?;
130 while let Some(chunk) = stream.next().await {
131 let resp = chunk.map_err(|err| {
132 Box::new(StorageError(format!("adls download failed: {err}")))
133 as Box<dyn std::error::Error + Send + Sync>
134 })?;
135 let bytes = resp.data.collect().await.map_err(|err| {
136 Box::new(StorageError(format!("adls download read failed: {err}")))
137 as Box<dyn std::error::Error + Send + Sync>
138 })?;
139 tokio::io::AsyncWriteExt::write_all(&mut file, &bytes).await?;
140 }
141 Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
142 })?;
143 Ok(dest)
144 }
145
146 fn upload_from_path(&self, local_path: &Path, uri: &str) -> FloeResult<()> {
147 let key = uri
148 .split_once(".dfs.core.windows.net/")
149 .map(|(_, tail)| tail)
150 .unwrap_or("")
151 .trim_start_matches('/')
152 .to_string();
153 if key.is_empty() {
154 return Err(Box::new(StorageError(
155 "adls upload requires a blob path".to_string(),
156 )));
157 }
158 let client = self.container_client.clone();
159 let path = local_path.to_path_buf();
160 self.runtime.block_on(async move {
161 let data = tokio::fs::read(path).await?;
162 let blob = client.blob_client(key);
163 blob.put_block_blob(data)
164 .content_type("application/octet-stream")
165 .into_future()
166 .await
167 .map_err(|err| {
168 Box::new(StorageError(format!("adls upload failed: {err}")))
169 as Box<dyn std::error::Error + Send + Sync>
170 })?;
171 Ok(())
172 })
173 }
174
175 fn resolve_uri(&self, path: &str) -> FloeResult<String> {
176 Ok(self.format_abfs(&self.full_path(path)))
177 }
178
179 fn copy_object(&self, src_uri: &str, dst_uri: &str) -> FloeResult<()> {
180 planner::copy_via_temp(self, src_uri, dst_uri)
181 }
182
183 fn delete_object(&self, uri: &str) -> FloeResult<()> {
184 let key = uri
185 .split_once(".dfs.core.windows.net/")
186 .map(|(_, tail)| tail)
187 .unwrap_or("")
188 .trim_start_matches('/')
189 .to_string();
190 if key.is_empty() {
191 return Ok(());
192 }
193 let client = self.container_client.clone();
194 self.runtime.block_on(async move {
195 let blob = client.blob_client(key);
196 blob.delete().into_future().await.map_err(|err| {
197 Box::new(StorageError(format!("adls delete failed: {err}")))
198 as Box<dyn std::error::Error + Send + Sync>
199 })?;
200 Ok(())
201 })
202 }
203
204 fn exists(&self, uri: &str) -> FloeResult<bool> {
205 let key = uri
206 .split_once(".dfs.core.windows.net/")
207 .map(|(_, tail)| tail)
208 .unwrap_or("")
209 .trim_start_matches('/')
210 .to_string();
211 planner::exists_by_key(self, &key)
212 }
213
214 fn read_object(&self, uri: &str) -> FloeResult<Option<StoredObject>> {
215 let key = adls_key_from_uri(uri)?;
216 let client = self.container_client.clone();
217 self.runtime.block_on(async move {
218 let blob = client.blob_client(key);
219 let mut stream = blob.get().into_stream();
220 let mut body = Vec::new();
221 let mut version = None;
222 while let Some(chunk) = stream.next().await {
223 let resp = match chunk {
224 Ok(resp) => resp,
225 Err(err) if is_not_found(&err) => return Ok(None),
226 Err(err) => {
227 return Err(
228 Box::new(StorageError(format!("adls download failed: {err}")))
229 as Box<dyn std::error::Error + Send + Sync>,
230 )
231 }
232 };
233 if version.is_none() {
234 version = Some(resp.blob.properties.etag.to_string());
235 }
236 let bytes = resp.data.collect().await.map_err(|err| {
237 Box::new(StorageError(format!("adls download read failed: {err}")))
238 as Box<dyn std::error::Error + Send + Sync>
239 })?;
240 body.extend_from_slice(&bytes);
241 }
242 Ok(version.map(|version| StoredObject { body, version }))
243 })
244 }
245
246 fn write_object_conditional(
247 &self,
248 uri: &str,
249 expected_version: Option<&str>,
250 body: &[u8],
251 ) -> FloeResult<ConditionalWrite> {
252 let key = adls_key_from_uri(uri)?;
253 let client = self.container_client.clone();
254 let body = body.to_vec();
255 self.runtime.block_on(async move {
256 let condition = expected_version
257 .map(|version| IfMatchCondition::Match(version.to_string()))
258 .unwrap_or_else(|| IfMatchCondition::NotMatch("*".to_string()));
259 match client
260 .blob_client(key)
261 .put_block_blob(body)
262 .if_match(condition)
263 .content_type("application/json")
264 .into_future()
265 .await
266 {
267 Ok(resp) => Ok(ConditionalWrite::Written { version: resp.etag }),
268 Err(err) if is_precondition(&err) => Ok(ConditionalWrite::Conflict),
269 Err(err) => Err(Box::new(StorageError(format!("adls upload failed: {err}")))
270 as Box<dyn std::error::Error + Send + Sync>),
271 }
272 })
273 }
274
275 fn delete_object_conditional(
276 &self,
277 uri: &str,
278 expected_version: Option<&str>,
279 ) -> FloeResult<ConditionalWrite> {
280 let Some(expected_version) = expected_version else {
281 return Ok(ConditionalWrite::Written {
282 version: "deleted".to_string(),
283 });
284 };
285 let key = adls_key_from_uri(uri)?;
286 let client = self.container_client.clone();
287 self.runtime.block_on(async move {
288 let request = client
289 .blob_client(key)
290 .delete()
291 .if_match(IfMatchCondition::Match(expected_version.to_string()));
292 match request.into_future().await {
293 Ok(_) => Ok(ConditionalWrite::Written {
294 version: "deleted".to_string(),
295 }),
296 Err(err) if is_precondition(&err) => Ok(ConditionalWrite::Conflict),
297 Err(err) if is_not_found(&err) => Ok(ConditionalWrite::Written {
298 version: "deleted".to_string(),
299 }),
300 Err(err) => Err(Box::new(StorageError(format!("adls delete failed: {err}")))
301 as Box<dyn std::error::Error + Send + Sync>),
302 }
303 })
304 }
305}
306
307fn adls_key_from_uri(uri: &str) -> FloeResult<String> {
308 let key = uri
309 .split_once(".dfs.core.windows.net/")
310 .map(|(_, tail)| tail)
311 .unwrap_or("")
312 .trim_start_matches('/')
313 .to_string();
314 if key.is_empty() {
315 return Err(Box::new(StorageError(
316 "adls state operation requires a blob path".to_string(),
317 )));
318 }
319 Ok(key)
320}
321
/// Reports whether an error's rendered message looks like a "not found"
/// response, i.e. contains `404` or `NotFound`.
fn is_not_found<E: std::fmt::Display>(err: &E) -> bool {
    let rendered = err.to_string();
    ["404", "NotFound"]
        .iter()
        .any(|needle| rendered.contains(*needle))
}
326
/// Reports whether an error's rendered message looks like a failed
/// conditional request.
///
/// Matches the `412` status, any mention of `condition`/`Condition`, and the
/// `BlobAlreadyExists` error code. The last marker covers create-only writes
/// (`If-None-Match: *`) against an existing blob, which Azure reports as
/// `409 BlobAlreadyExists`, not as a `412`.
fn is_precondition<E: std::fmt::Display>(err: &E) -> bool {
    const MARKERS: [&str; 4] = ["412", "condition", "Condition", "BlobAlreadyExists"];
    let text = err.to_string();
    MARKERS.iter().any(|marker| text.contains(*marker))
}
331
332pub fn parse_adls_uri(uri: &str) -> FloeResult<AdlsLocation> {
333 uri::parse_abfs_uri(uri)
334}
335
336pub fn format_abfs_uri(container: &str, account: &str, path: &str) -> String {
337 uri::format_abfs_uri(container, account, path)
338}
339
340pub type AdlsLocation = uri::AdlsLocation;