#![allow(unused_imports)]
use super::*;
/// Writes the uploaded edit inputs into `dir` and resolves the edit-region mode.
///
/// Inputs are the reference images, an optional mask and an optional selection hint.
///
/// Returns `(ref_paths, mask_path, edit_region_mode)`:
/// - `ref_paths`: one `ref-{index}.{ext}` file per reference upload, in request order.
///   `ext` is taken from the upload's file name and falls back to `png`. The
///   selection-hint file is appended at the end when the mode is `"reference-hint"`.
/// - `mask_path`: `dir/mask.png` when the request carried a mask.
/// - `edit_region_mode`: the value returned by `edit_region_mode_for_request`.
///
/// A selection hint, when present, is written to `dir/selection-hint.png`.
///
/// # Errors
/// - Returns the localized "local editing is not supported by the current credential"
///   message when the request has a mask or selection hint but the mode is `"none"`.
///   This check runs before any file is written.
/// - Returns the `to_string()` of any I/O error raised while writing a file.
pub(crate) fn write_edit_inputs(
    request: &EditRequest,
    dir: &FsPath,
) -> Result<(Vec<PathBuf>, Option<PathBuf>, String), String> {
    /// Writes `bytes` to `path` and hands the path back, stringifying I/O errors.
    fn write_input(path: PathBuf, bytes: impl AsRef<[u8]>) -> Result<PathBuf, String> {
        fs::write(&path, bytes).map_err(|error| error.to_string())?;
        Ok(path)
    }

    let edit_region_mode = edit_region_mode_for_request(request);
    // Validate the provider capability before touching the disk, so a refused
    // request does not leave reference/mask files behind in `dir`.
    if edit_region_mode == "none" && (request.mask.is_some() || request.selection_hint.is_some())
    {
        return Err("当前凭证不支持局部编辑。请切换到「多图参考」或更换凭证。".to_string());
    }

    // +1 leaves room for the selection hint in "reference-hint" mode.
    let mut ref_paths = Vec::with_capacity(request.refs.len() + 1);
    for (index, upload) in request.refs.iter().enumerate() {
        let ext = FsPath::new(&upload.name)
            .extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or("png");
        ref_paths.push(write_input(
            dir.join(format!("ref-{index}.{ext}")),
            &upload.bytes,
        )?);
    }

    let mask_path = request
        .mask
        .as_ref()
        .map(|mask| write_input(dir.join("mask.png"), &mask.bytes))
        .transpose()?;
    let selection_hint_path = request
        .selection_hint
        .as_ref()
        .map(|hint| write_input(dir.join("selection-hint.png"), &hint.bytes))
        .transpose()?;

    // In "reference-hint" mode the selection hint is sent as an extra reference image.
    if edit_region_mode == "reference-hint"
        && let Some(path) = selection_hint_path
    {
        ref_paths.push(path);
    }
    Ok((ref_paths, mask_path, edit_region_mode))
}
/// Decides how the edit region (mask / selection hint) will be conveyed.
///
/// Returns `"none"` when the request carries neither a mask nor a selection hint.
/// Otherwise it defers to `provider_edit_region_mode` for the request's provider.
/// Callers in this file compare the result against `"none"`, `"native-mask"` and
/// `"reference-hint"`.
pub(crate) fn edit_region_mode_for_request(request: &EditRequest) -> String {
    if request.mask.is_none() && request.selection_hint.is_none() {
        return "none".to_string();
    }
    provider_edit_region_mode(request.provider.as_deref())
}
/// Builds the JSON metadata recorded with an `images edit` job.
///
/// Scalar request options are copied through as-is. Uploaded images are summarized
/// only: a reference count plus presence flags for the mask and the selection hint.
/// Their bytes are not included. The resolved edit-region mode is also recorded.
pub(crate) fn edit_request_metadata(request: &EditRequest) -> Value {
    let edit_region_mode = edit_region_mode_for_request(request);
    let ref_count = request.refs.len();
    let has_mask = request.mask.is_some();
    let has_selection_hint = request.selection_hint.is_some();

    json!({
        "prompt": request.prompt,
        "provider": request.provider,
        "size": request.size,
        "format": request.format,
        "quality": request.quality,
        "background": request.background,
        "n": request.n,
        "compression": request.compression,
        "input_fidelity": request.input_fidelity,
        "moderation": request.moderation,
        "storage_targets": request.storage_targets,
        "fallback_targets": request.fallback_targets,
        "ref_count": ref_count,
        "has_mask": has_mask,
        // Key name kept as-is for existing consumers; the value is a presence flag.
        "selection_hint": has_selection_hint,
        "edit_region_mode": edit_region_mode,
    })
}
/// Runs an `images edit` request end to end and wraps the outcome as a local job.
///
/// The request is validated (non-blank prompt, at least one reference image, valid
/// `n`), an explicit `n` is normalized, and the inputs are written into `dir` via
/// `write_edit_inputs`. Then one of two paths runs:
/// - a single CLI call, when the provider accepts `n` natively or only one output is
///   requested;
/// - one CLI call per requested output. Each call gets its own recovery job id and
///   directory. Partial outputs are forwarded through `stream` when one is supplied.
///   The per-call payloads are merged and a batch recovery summary is written.
///
/// Returns a JSON object with `job_id`, `job`, a single terminal `events` entry and
/// the raw `payload`.
///
/// # Errors
/// Returns an error string for:
/// - a blank prompt;
/// - missing reference images;
/// - a rejected `n`;
/// - failures while writing inputs;
/// - a failing single-call CLI result;
/// - a failure to write the batch recovery summary.
pub(crate) fn run_edit_request(
    mut request: EditRequest,
    fallback_id: String,
    dir: PathBuf,
    stream: Option<StreamContext>,
) -> Result<Value, String> {
    if request.prompt.trim().is_empty() {
        return Err("Prompt is required.".to_string());
    }
    if request.refs.is_empty() {
        return Err("At least one reference image is required.".to_string());
    }
    let output_count = requested_n(request.n)?;
    // Normalize an explicit `n` to the validated count; an absent `n` stays absent.
    if request.n.is_some() {
        request.n = Some(output_count);
    }
    let (ref_paths, mask_path, edit_region_mode) = write_edit_inputs(&request, &dir)?;
    // Only "native-mask" mode forwards the mask file to the CLI. Computed once so the
    // single-call and batched paths cannot disagree.
    let native_mask = if edit_region_mode == "native-mask" {
        mask_path.as_deref()
    } else {
        None
    };
    let provider_supports_n = provider_supports_n(request.provider.as_deref());
    let payload = if provider_supports_n || output_count == 1 {
        let out = dir.join(format!(
            "out.{}",
            output_extension(request.format.as_deref())
        ));
        cli_json_result(&edit_args_with_recovery(
            &request,
            &ref_paths,
            native_mask,
            &out,
            provider_supports_n,
            Some((&fallback_id, &dir)),
        ))?
    } else {
        // The provider cannot produce `n` outputs in one call: fan out one call per
        // output, each with its own recovery job id and directory.
        let recovery_targets = (0..output_count)
            .map(|index| {
                (
                    batch_recovery_job_id(&fallback_id, index),
                    batch_recovery_job_dir(&dir, index),
                )
            })
            .collect::<Vec<_>>();
        let arg_sets = recovery_targets
            .iter()
            .enumerate()
            .map(|(index, (recovery_job_id, recovery_job_dir))| {
                edit_args_with_recovery(
                    &request,
                    &ref_paths,
                    native_mask,
                    // NOTE(review): assumes `output_count` fits in u8 (presumably
                    // capped by `requested_n`), so this cast does not truncate.
                    &batch_output_path(&dir, request.format.as_deref(), index as u8),
                    false,
                    Some((recovery_job_id.as_str(), recovery_job_dir.as_path())),
                )
            })
            .collect::<Vec<_>>();
        // Accumulates partial outputs for streaming; only touched when `stream` is set.
        let partials = Arc::new(Mutex::new(Vec::<Value>::new()));
        let stream_for_cb = stream.clone();
        let batch = run_payloads_concurrently_streaming(arg_sets, move |index, payload| {
            if let Some(ctx) = &stream_for_cb {
                // A poisoned lock only means another callback panicked; keep streaming.
                let mut list = partials
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                apply_partial_output(ctx, &mut list, index, payload);
            }
        });
        let child_dirs = recovery_targets
            .iter()
            .map(|(_, recovery_dir)| recovery_dir.clone())
            .collect::<Vec<_>>();
        let merged = merge_batch_payloads(
            "images edit",
            output_count.into(),
            batch.payloads,
            batch.errors,
        );
        let generation_slots =
            generation_slots_from_batch_payload(output_count.into(), &merged, &child_dirs);
        let outputs_present = merged
            .get("output")
            .and_then(|output| output.get("files"))
            .and_then(Value::as_array)
            .map(Vec::len)
            .unwrap_or(0);
        let failures = merged
            .get("batch")
            .and_then(|batch| batch.get("failure_count"))
            .and_then(Value::as_u64)
            .unwrap_or(0) as usize;
        write_batch_recovery_summary(
            &fallback_id,
            &dir,
            &child_dirs,
            outputs_present,
            failures,
            generation_slots,
        )
        .map_err(app_error)?;
        merged
    };
    let request_meta = edit_request_metadata(&request);
    let job = job_from_payload(&payload, &fallback_id, "images edit", request_meta);
    let event_type = terminal_event_type(job.get("status").and_then(Value::as_str));
    Ok(json!({
        "job_id": job.get("id").cloned().unwrap_or(Value::Null),
        "job": job,
        "events": [{
            "seq": 1,
            "kind": "local",
            "type": event_type,
            "data": {"status": job.get("status"), "output": payload.get("output"), "error": payload.get("error")}
        }],
        "payload": payload,
    }))
}