Skip to main content

floe_core/io/storage/providers/adls.rs

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use azure_core::prelude::IfMatchCondition;
5use azure_identity::{DefaultAzureCredential, TokenCredentialOptions};
6use azure_storage::StorageCredentials;
7use azure_storage_blobs::prelude::{BlobServiceClient, ContainerClient};
8use futures::StreamExt;
9use tokio::runtime::Runtime;
10
11use crate::errors::StorageError;
12use crate::io::storage::{
13    planner, uri, validation, ConditionalWrite, ObjectRef, StorageClient, StoredObject,
14};
15use crate::{config, FloeResult};
16
17pub struct AdlsClient {
18    account: String,
19    container: String,
20    prefix: String,
21    runtime: Runtime,
22    container_client: ContainerClient,
23}
24
25impl AdlsClient {
26    pub fn new(definition: &config::StorageDefinition) -> FloeResult<Self> {
27        let account =
28            validation::require_field(definition, definition.account.as_ref(), "account", "adls")?;
29        let container = validation::require_field(
30            definition,
31            definition.container.as_ref(),
32            "container",
33            "adls",
34        )?;
35        let prefix = definition.prefix.clone().unwrap_or_default();
36        let runtime = tokio::runtime::Builder::new_current_thread()
37            .enable_all()
38            .build()
39            .map_err(|err| Box::new(StorageError(format!("adls runtime init failed: {err}"))))?;
40        let credential = DefaultAzureCredential::create(TokenCredentialOptions::default())
41            .map_err(|err| Box::new(StorageError(format!("adls credential init failed: {err}"))))?;
42        let storage_credentials = StorageCredentials::token_credential(Arc::new(credential));
43        let service_client = BlobServiceClient::new(account.clone(), storage_credentials);
44        let container_client = service_client.container_client(container.clone());
45        Ok(Self {
46            account,
47            container,
48            prefix,
49            runtime,
50            container_client,
51        })
52    }
53
54    fn base_prefix(&self) -> String {
55        planner::normalize_separators(&self.prefix)
56    }
57
58    fn full_path(&self, path: &str) -> String {
59        let prefix = self.base_prefix();
60        let joined = planner::join_prefix(&prefix, &planner::normalize_separators(path));
61        joined.trim_start_matches('/').to_string()
62    }
63
64    fn format_abfs(&self, path: &str) -> String {
65        format_abfs_uri(&self.container, &self.account, path)
66    }
67}
68
69impl StorageClient for AdlsClient {
70    fn list(&self, prefix_or_path: &str) -> FloeResult<Vec<ObjectRef>> {
71        let prefix = self.full_path(prefix_or_path);
72        let container = self.container.clone();
73        let account = self.account.clone();
74        let client = self.container_client.clone();
75        self.runtime.block_on(async move {
76            let mut refs = Vec::new();
77            let mut stream = client.list_blobs().prefix(prefix.clone()).into_stream();
78            while let Some(resp) = stream.next().await {
79                let resp = resp.map_err(|err| {
80                    Box::new(StorageError(format!("adls list failed: {err}")))
81                        as Box<dyn std::error::Error + Send + Sync>
82                })?;
83                for blob in resp.blobs.blobs() {
84                    let key = blob.name.clone();
85                    let uri = if key.is_empty() {
86                        format!("abfs://{}@{}.dfs.core.windows.net", container, account)
87                    } else {
88                        format!(
89                            "abfs://{}@{}.dfs.core.windows.net/{}",
90                            container, account, key
91                        )
92                    };
93                    refs.push(planner::object_ref(
94                        uri,
95                        key,
96                        Some(blob.properties.last_modified.to_string()),
97                        Some(blob.properties.content_length),
98                    ));
99                }
100            }
101            Ok(planner::stable_sort_refs(refs))
102        })
103    }
104
105    fn download_to_temp(&self, uri: &str, temp_dir: &Path) -> FloeResult<PathBuf> {
106        let key = uri
107            .split_once(".dfs.core.windows.net/")
108            .map(|(_, tail)| tail)
109            .unwrap_or("")
110            .trim_start_matches('/')
111            .to_string();
112        let key = if key.is_empty() {
113            return Err(Box::new(StorageError(
114                "adls download requires a blob path".to_string(),
115            )));
116        } else {
117            key
118        };
119        let dest = planner::temp_path_for_key(temp_dir, &key);
120        let dest_clone = dest.clone();
121        let client = self.container_client.clone();
122        let key_clone = key.clone();
123        self.runtime.block_on(async move {
124            if let Some(parent) = dest_clone.parent() {
125                tokio::fs::create_dir_all(parent).await?;
126            }
127            let blob = client.blob_client(key_clone);
128            let mut stream = blob.get().into_stream();
129            let mut file = tokio::fs::File::create(&dest_clone).await?;
130            while let Some(chunk) = stream.next().await {
131                let resp = chunk.map_err(|err| {
132                    Box::new(StorageError(format!("adls download failed: {err}")))
133                        as Box<dyn std::error::Error + Send + Sync>
134                })?;
135                let bytes = resp.data.collect().await.map_err(|err| {
136                    Box::new(StorageError(format!("adls download read failed: {err}")))
137                        as Box<dyn std::error::Error + Send + Sync>
138                })?;
139                tokio::io::AsyncWriteExt::write_all(&mut file, &bytes).await?;
140            }
141            Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
142        })?;
143        Ok(dest)
144    }
145
146    fn upload_from_path(&self, local_path: &Path, uri: &str) -> FloeResult<()> {
147        let key = uri
148            .split_once(".dfs.core.windows.net/")
149            .map(|(_, tail)| tail)
150            .unwrap_or("")
151            .trim_start_matches('/')
152            .to_string();
153        if key.is_empty() {
154            return Err(Box::new(StorageError(
155                "adls upload requires a blob path".to_string(),
156            )));
157        }
158        let client = self.container_client.clone();
159        let path = local_path.to_path_buf();
160        self.runtime.block_on(async move {
161            let data = tokio::fs::read(path).await?;
162            let blob = client.blob_client(key);
163            blob.put_block_blob(data)
164                .content_type("application/octet-stream")
165                .into_future()
166                .await
167                .map_err(|err| {
168                    Box::new(StorageError(format!("adls upload failed: {err}")))
169                        as Box<dyn std::error::Error + Send + Sync>
170                })?;
171            Ok(())
172        })
173    }
174
175    fn resolve_uri(&self, path: &str) -> FloeResult<String> {
176        Ok(self.format_abfs(&self.full_path(path)))
177    }
178
179    fn copy_object(&self, src_uri: &str, dst_uri: &str) -> FloeResult<()> {
180        planner::copy_via_temp(self, src_uri, dst_uri)
181    }
182
183    fn delete_object(&self, uri: &str) -> FloeResult<()> {
184        let key = uri
185            .split_once(".dfs.core.windows.net/")
186            .map(|(_, tail)| tail)
187            .unwrap_or("")
188            .trim_start_matches('/')
189            .to_string();
190        if key.is_empty() {
191            return Ok(());
192        }
193        let client = self.container_client.clone();
194        self.runtime.block_on(async move {
195            let blob = client.blob_client(key);
196            blob.delete().into_future().await.map_err(|err| {
197                Box::new(StorageError(format!("adls delete failed: {err}")))
198                    as Box<dyn std::error::Error + Send + Sync>
199            })?;
200            Ok(())
201        })
202    }
203
204    fn exists(&self, uri: &str) -> FloeResult<bool> {
205        let key = uri
206            .split_once(".dfs.core.windows.net/")
207            .map(|(_, tail)| tail)
208            .unwrap_or("")
209            .trim_start_matches('/')
210            .to_string();
211        planner::exists_by_key(self, &key)
212    }
213
214    fn read_object(&self, uri: &str) -> FloeResult<Option<StoredObject>> {
215        let key = adls_key_from_uri(uri)?;
216        let client = self.container_client.clone();
217        self.runtime.block_on(async move {
218            let blob = client.blob_client(key);
219            let mut stream = blob.get().into_stream();
220            let mut body = Vec::new();
221            let mut version = None;
222            while let Some(chunk) = stream.next().await {
223                let resp = match chunk {
224                    Ok(resp) => resp,
225                    Err(err) if is_not_found(&err) => return Ok(None),
226                    Err(err) => {
227                        return Err(
228                            Box::new(StorageError(format!("adls download failed: {err}")))
229                                as Box<dyn std::error::Error + Send + Sync>,
230                        )
231                    }
232                };
233                if version.is_none() {
234                    version = Some(resp.blob.properties.etag.to_string());
235                }
236                let bytes = resp.data.collect().await.map_err(|err| {
237                    Box::new(StorageError(format!("adls download read failed: {err}")))
238                        as Box<dyn std::error::Error + Send + Sync>
239                })?;
240                body.extend_from_slice(&bytes);
241            }
242            Ok(version.map(|version| StoredObject { body, version }))
243        })
244    }
245
246    fn write_object_conditional(
247        &self,
248        uri: &str,
249        expected_version: Option<&str>,
250        body: &[u8],
251    ) -> FloeResult<ConditionalWrite> {
252        let key = adls_key_from_uri(uri)?;
253        let client = self.container_client.clone();
254        let body = body.to_vec();
255        self.runtime.block_on(async move {
256            let condition = expected_version
257                .map(|version| IfMatchCondition::Match(version.to_string()))
258                .unwrap_or_else(|| IfMatchCondition::NotMatch("*".to_string()));
259            match client
260                .blob_client(key)
261                .put_block_blob(body)
262                .if_match(condition)
263                .content_type("application/json")
264                .into_future()
265                .await
266            {
267                Ok(resp) => Ok(ConditionalWrite::Written { version: resp.etag }),
268                Err(err) if is_precondition(&err) => Ok(ConditionalWrite::Conflict),
269                Err(err) => Err(Box::new(StorageError(format!("adls upload failed: {err}")))
270                    as Box<dyn std::error::Error + Send + Sync>),
271            }
272        })
273    }
274
275    fn delete_object_conditional(
276        &self,
277        uri: &str,
278        expected_version: Option<&str>,
279    ) -> FloeResult<ConditionalWrite> {
280        let Some(expected_version) = expected_version else {
281            return Ok(ConditionalWrite::Written {
282                version: "deleted".to_string(),
283            });
284        };
285        let key = adls_key_from_uri(uri)?;
286        let client = self.container_client.clone();
287        self.runtime.block_on(async move {
288            let request = client
289                .blob_client(key)
290                .delete()
291                .if_match(IfMatchCondition::Match(expected_version.to_string()));
292            match request.into_future().await {
293                Ok(_) => Ok(ConditionalWrite::Written {
294                    version: "deleted".to_string(),
295                }),
296                Err(err) if is_precondition(&err) => Ok(ConditionalWrite::Conflict),
297                Err(err) if is_not_found(&err) => Ok(ConditionalWrite::Written {
298                    version: "deleted".to_string(),
299                }),
300                Err(err) => Err(Box::new(StorageError(format!("adls delete failed: {err}")))
301                    as Box<dyn std::error::Error + Send + Sync>),
302            }
303        })
304    }
305}
306
307fn adls_key_from_uri(uri: &str) -> FloeResult<String> {
308    let key = uri
309        .split_once(".dfs.core.windows.net/")
310        .map(|(_, tail)| tail)
311        .unwrap_or("")
312        .trim_start_matches('/')
313        .to_string();
314    if key.is_empty() {
315        return Err(Box::new(StorageError(
316            "adls state operation requires a blob path".to_string(),
317        )));
318    }
319    Ok(key)
320}
321
/// Returns `true` when `err` describes an HTTP 404 / "not found" response.
///
/// Errors are only inspected through their `Display` text. That text may include request ids,
/// ETags or dates, so the status code must appear as a standalone token. A plain substring test
/// would also fire on `404` inside a hex id and misreport a transient failure as "missing".
fn is_not_found<E: std::fmt::Display>(err: &E) -> bool {
    let text = err.to_string();
    contains_status_token(&text, "404") || text.contains("NotFound")
}

/// Returns `true` when `err` means a conditional write/delete was rejected because the stored
/// object did not match the expected state.
///
/// Covers two Blob service rejections:
/// - HTTP 412 / `ConditionNotMet` (an `If-Match` mismatch);
/// - `BlobAlreadyExists`, which the service returns with HTTP 409 when an `If-None-Match: *`
///   (create-only) write finds an existing blob.
fn is_precondition<E: std::fmt::Display>(err: &E) -> bool {
    let text = err.to_string();
    contains_status_token(&text, "412")
        || text.contains("condition")
        || text.contains("Condition")
        || text.contains("BlobAlreadyExists")
}

/// Returns `true` if `code` occurs in `text` with no ASCII letter or digit directly before or
/// after it. For example, `"Status: 404,"` matches `"404"` but `"9f404abc"` does not.
fn contains_status_token(text: &str, code: &str) -> bool {
    text.match_indices(code).any(|(start, found)| {
        let before = text[..start].chars().next_back();
        let after = text[start + found.len()..].chars().next();
        !matches!(before, Some(c) if c.is_ascii_alphanumeric())
            && !matches!(after, Some(c) if c.is_ascii_alphanumeric())
    })
}
331
332pub fn parse_adls_uri(uri: &str) -> FloeResult<AdlsLocation> {
333    uri::parse_abfs_uri(uri)
334}
335
336pub fn format_abfs_uri(container: &str, account: &str, path: &str) -> String {
337    uri::format_abfs_uri(container, account, path)
338}
339
340pub type AdlsLocation = uri::AdlsLocation;