floe_core/io/write/strategy/
mod.rs1use std::path::Path;
2
3use crate::io::storage::Target;
4use crate::{config, io, ConfigError, FloeResult};
5
6use super::parts;
7
8mod append;
9mod overwrite;
10
/// Identifies which sink a batch of part files belongs to.
///
/// Each variant carries a `format` label. In this module the label is only
/// used to build configuration error messages for accepted outputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartScope {
    /// Parts written for the accepted sink (`sink.accepted.path`).
    Accepted { format: &'static str },
    /// Parts written for the rejected sink (`sink.rejected.path`).
    Rejected { format: &'static str },
}

/// Describes the family of part files a write operates on.
///
/// The type is a small `Copy` value made of static strings, so it also
/// implements `PartialEq`/`Eq`; callers and tests can compare specs directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PartSpec {
    /// File extension of the part files; it is handed to the part-name
    /// allocator and to the helpers that recognise existing parts.
    pub extension: &'static str,
    /// Sink the parts belong to.
    pub scope: PartScope,
}
22
23pub struct WriteContext<'a> {
24 pub target: &'a Target,
25 pub cloud: &'a mut io::storage::CloudClient,
26 pub resolver: &'a config::StorageResolver,
27 pub entity: &'a config::EntityConfig,
28}
29
30pub trait ModeStrategy {
31 fn mode(&self) -> config::WriteMode;
32 fn part_allocator(
33 &self,
34 ctx: &mut WriteContext<'_>,
35 spec: PartSpec,
36 ) -> FloeResult<parts::PartNameAllocator>;
37}
38
/// Cloud backend of a target, with the identifiers quoted in error messages.
#[derive(Debug, Clone)]
enum CloudProvider {
    /// Amazon S3; no extra identifiers are carried.
    S3,
    /// Google Cloud Storage, identified by its bucket.
    Gcs { bucket: String },
    /// Azure Data Lake Storage, identified by container and account.
    Adls { container: String, account: String },
}
45
46pub fn strategy_for(mode: config::WriteMode) -> &'static dyn ModeStrategy {
47 match mode {
48 config::WriteMode::Overwrite => &overwrite::OVERWRITE_STRATEGY,
49 config::WriteMode::Append => &append::APPEND_STRATEGY,
50 }
51}
52
53pub fn ensure_mode_supported(mode: config::WriteMode) -> FloeResult<()> {
54 match mode {
55 config::WriteMode::Overwrite => Ok(()),
56 config::WriteMode::Append => Ok(()),
57 }
58}
59
60pub fn accepted_parquet_spec() -> PartSpec {
61 PartSpec {
62 extension: "parquet",
63 scope: PartScope::Accepted { format: "parquet" },
64 }
65}
66
67pub fn rejected_csv_spec() -> PartSpec {
68 PartSpec {
69 extension: "csv",
70 scope: PartScope::Rejected { format: "csv" },
71 }
72}
73
74pub fn append_part_allocator(
75 _ctx: &mut WriteContext<'_>,
76 spec: PartSpec,
77) -> FloeResult<parts::PartNameAllocator> {
78 Ok(parts::PartNameAllocator::unique(spec.extension))
79}
80
81pub fn overwrite_part_allocator(
82 ctx: &mut WriteContext<'_>,
83 spec: PartSpec,
84) -> FloeResult<parts::PartNameAllocator> {
85 match ctx.target {
86 Target::Local { base_path, .. } => {
87 let base_path = Path::new(base_path);
88 let _ = parts::clear_local_part_files(base_path, spec.extension)?;
89 Ok(parts::PartNameAllocator::from_next_index(0, spec.extension))
90 }
91 Target::S3 { .. } | Target::Gcs { .. } | Target::Adls { .. } => {
92 clear_cloud_parts(ctx, spec)?;
93 Ok(parts::PartNameAllocator::from_next_index(0, spec.extension))
94 }
95 }
96}
97
98fn clear_cloud_parts(ctx: &mut WriteContext<'_>, spec: PartSpec) -> FloeResult<()> {
99 let (list_prefix, objects) = list_part_objects(ctx, spec)?;
100 let client = ctx
101 .cloud
102 .client_for(ctx.resolver, ctx.target.storage(), ctx.entity)?;
103 for object in objects
104 .into_iter()
105 .filter(|obj| obj.key.starts_with(&list_prefix))
106 .filter(|obj| parts::is_part_key(&obj.key, spec.extension))
107 {
108 client.delete_object(&object.uri)?;
109 }
110 Ok(())
111}
112
113pub(crate) fn list_part_objects(
114 ctx: &mut WriteContext<'_>,
115 spec: PartSpec,
116) -> FloeResult<(String, Vec<io::storage::ObjectRef>)> {
117 match ctx.target {
118 Target::S3 {
119 storage, base_key, ..
120 } => {
121 let provider = CloudProvider::S3;
122 let list_prefix = list_prefix(ctx.entity, base_key, &provider, spec)?;
123 let client = ctx.cloud.client_for(ctx.resolver, storage, ctx.entity)?;
124 let objects = client.list(&list_prefix)?;
125 Ok((list_prefix, objects))
126 }
127 Target::Gcs {
128 storage,
129 bucket,
130 base_key,
131 ..
132 } => {
133 let provider = CloudProvider::Gcs {
134 bucket: bucket.clone(),
135 };
136 let list_prefix = list_prefix(ctx.entity, base_key, &provider, spec)?;
137 let client = ctx.cloud.client_for(ctx.resolver, storage, ctx.entity)?;
138 let objects = client.list(&list_prefix)?;
139 Ok((list_prefix, objects))
140 }
141 Target::Adls {
142 storage,
143 container,
144 account,
145 base_path,
146 ..
147 } => {
148 let provider = CloudProvider::Adls {
149 container: container.clone(),
150 account: account.clone(),
151 };
152 let list_prefix = list_prefix(ctx.entity, base_path, &provider, spec)?;
153 let client = ctx.cloud.client_for(ctx.resolver, storage, ctx.entity)?;
154 let objects = client.list(&list_prefix)?;
155 Ok((list_prefix, objects))
156 }
157 Target::Local { .. } => Err(Box::new(ConfigError(
158 "cloud part listing requested for local target".to_string(),
159 ))),
160 }
161}
162
163fn list_prefix(
164 entity: &config::EntityConfig,
165 base_path: &str,
166 provider: &CloudProvider,
167 spec: PartSpec,
168) -> FloeResult<String> {
169 let prefix = base_path.trim_matches('/');
170 if prefix.is_empty() {
171 return Err(Box::new(prefix_error(entity, provider, spec)));
172 }
173 Ok(format!("{prefix}/"))
174}
175
176fn prefix_error(
177 entity: &config::EntityConfig,
178 provider: &CloudProvider,
179 spec: PartSpec,
180) -> ConfigError {
181 match (&spec.scope, provider) {
182 (PartScope::Accepted { format }, CloudProvider::S3) => ConfigError(format!(
183 "entity.name={} sink.accepted.path must not be bucket root for s3 {format} outputs",
184 entity.name
185 )),
186 (PartScope::Accepted { format }, CloudProvider::Gcs { bucket }) => ConfigError(format!(
187 "entity.name={} sink.accepted.path must not be bucket root for gcs {format} outputs (bucket={})",
188 entity.name, bucket
189 )),
190 (PartScope::Accepted { format }, CloudProvider::Adls { container, account }) => {
191 ConfigError(format!(
192 "entity.name={} sink.accepted.path must not be container root for adls {format} outputs (container={}, account={})",
193 entity.name, container, account
194 ))
195 }
196 (PartScope::Rejected { .. }, CloudProvider::S3) => ConfigError(format!(
197 "entity.name={} sink.rejected.path must not be bucket root for s3 outputs",
198 entity.name
199 )),
200 (PartScope::Rejected { .. }, CloudProvider::Gcs { .. }) => ConfigError(format!(
201 "entity.name={} sink.rejected.path must not be bucket root for gcs outputs",
202 entity.name
203 )),
204 (PartScope::Rejected { .. }, CloudProvider::Adls { .. }) => ConfigError(format!(
205 "entity.name={} sink.rejected.path must not be container root for adls outputs",
206 entity.name
207 )),
208 }
209}