gproxy_protocol/transform/openai/create_image/openai_response/
stream.rs1use std::collections::HashMap;
2
3use crate::openai::count_tokens::types as ot;
4use crate::openai::create_image::stream::ImageGenerationStreamEvent;
5use crate::openai::create_response::response::ResponseBody;
6use crate::openai::create_response::stream::ResponseStreamEvent;
7use crate::openai::create_response::types as rt;
8use crate::transform::openai::create_image::utils::{
9 best_effort_image_usage_from_response_usage, stream_background_from_response_config,
10 stream_error_from_response_error, stream_output_format_from_response_config,
11 stream_quality_from_response_config_for_create_image,
12 stream_size_from_response_config_for_create_image,
13};
14
15#[derive(Debug, Clone, Default)]
23pub struct ResponseStreamToImageStream {
24 created_at: u64,
25 background: Option<ot::ResponseImageGenerationBackground>,
26 output_format: Option<ot::ResponseImageGenerationOutputFormat>,
27 quality: Option<ot::ResponseImageGenerationQuality>,
28 size: Option<ot::ResponseImageGenerationSize>,
29 usage: Option<rt::ResponseUsage>,
30 results: HashMap<String, String>,
32 partial_count: u32,
34 finished: bool,
35}
36
37impl ResponseStreamToImageStream {
38 fn update_config_from_response(&mut self, response: &ResponseBody) {
39 self.created_at = response.created_at;
40 if let Some(usage) = response.usage.as_ref() {
41 self.usage = Some(usage.clone());
42 }
43
44 for tool in &response.tools {
45 let rt::ResponseTool::ImageGeneration(image_tool) = tool else {
46 continue;
47 };
48 if let Some(ref bg) = image_tool.background {
49 self.background = Some(bg.clone());
50 }
51 if let Some(ref fmt) = image_tool.output_format {
52 self.output_format = Some(fmt.clone());
53 }
54 if let Some(ref q) = image_tool.quality {
55 self.quality = Some(q.clone());
56 }
57 if let Some(ref s) = image_tool.size {
58 self.size = Some(s.clone());
59 }
60 }
61 }
62
63 fn collect_image_result(&mut self, item: &rt::ResponseOutputItem) {
64 let rt::ResponseOutputItem::ImageGenerationCall(call) = item else {
65 return;
66 };
67 if let Some(result) = call.result.as_deref().filter(|s| !s.is_empty()) {
68 self.results.insert(call.id.clone(), result.to_string());
69 }
70 }
71
72 fn emit_partial(&mut self, b64_json: String, out: &mut Vec<ImageGenerationStreamEvent>) {
73 let index = self.partial_count;
74 self.partial_count += 1;
75 out.push(ImageGenerationStreamEvent::PartialImage {
76 b64_json,
77 background: stream_background_from_response_config(self.background.as_ref()),
78 created_at: self.created_at,
79 output_format: stream_output_format_from_response_config(self.output_format.as_ref()),
80 partial_image_index: index,
81 quality: stream_quality_from_response_config_for_create_image(self.quality.as_ref()),
82 size: stream_size_from_response_config_for_create_image(self.size.as_ref()),
83 });
84 }
85
86 fn emit_completed(&mut self, b64_json: String, out: &mut Vec<ImageGenerationStreamEvent>) {
87 out.push(ImageGenerationStreamEvent::Completed {
88 b64_json,
89 background: stream_background_from_response_config(self.background.as_ref()),
90 created_at: self.created_at,
91 output_format: stream_output_format_from_response_config(self.output_format.as_ref()),
92 quality: stream_quality_from_response_config_for_create_image(self.quality.as_ref()),
93 size: stream_size_from_response_config_for_create_image(self.size.as_ref()),
94 usage: best_effort_image_usage_from_response_usage(self.usage.as_ref()),
95 });
96 }
97
98 pub fn on_event(
99 &mut self,
100 event: ResponseStreamEvent,
101 out: &mut Vec<ImageGenerationStreamEvent>,
102 ) {
103 if self.finished {
104 return;
105 }
106
107 match event {
108 ResponseStreamEvent::Created { response, .. }
110 | ResponseStreamEvent::Queued { response, .. }
111 | ResponseStreamEvent::InProgress { response, .. } => {
112 self.update_config_from_response(&response);
113 for item in &response.output {
114 self.collect_image_result(item);
115 }
116 }
117
118 ResponseStreamEvent::OutputItemAdded { item, .. }
120 | ResponseStreamEvent::OutputItemDone { item, .. } => {
121 self.collect_image_result(&item);
122 }
123
124 ResponseStreamEvent::ImageGenerationCallPartialImage {
126 partial_image_b64, ..
127 } => {
128 self.emit_partial(partial_image_b64, out);
129 }
130
131 ResponseStreamEvent::ImageGenerationCallCompleted { item_id, .. } => {
133 if let Some(b64) = self.results.remove(&item_id) {
134 self.results.insert(item_id, b64);
136 }
137 }
138
139 ResponseStreamEvent::Completed { response, .. } => {
141 self.update_config_from_response(&response);
142 for item in &response.output {
143 self.collect_image_result(item);
144 }
145 self.finalize(out);
146 }
147
148 ResponseStreamEvent::Incomplete { response, .. } => {
150 self.update_config_from_response(&response);
151 for item in &response.output {
152 self.collect_image_result(item);
153 }
154 self.finalize(out);
155 }
156
157 ResponseStreamEvent::Failed { response, .. } => {
159 self.update_config_from_response(&response);
160 let message = response
161 .error
162 .map(|e| e.message)
163 .unwrap_or_else(|| "image generation failed".to_string());
164 out.push(ImageGenerationStreamEvent::Error {
165 error: stream_error_from_response_error(None, message, None),
166 });
167 self.finished = true;
168 }
169
170 ResponseStreamEvent::Error { error, .. } => {
172 out.push(ImageGenerationStreamEvent::Error {
173 error: stream_error_from_response_error(error.code, error.message, error.param),
174 });
175 self.finished = true;
176 }
177
178 _ => {}
180 }
181 }
182
183 fn finalize(&mut self, out: &mut Vec<ImageGenerationStreamEvent>) {
184 if self.finished {
185 return;
186 }
187 self.finished = true;
188
189 let results = std::mem::take(&mut self.results);
191 for (_item_id, b64) in results {
192 self.emit_completed(b64, out);
193 }
194 }
195
196 pub fn finish(&mut self, out: &mut Vec<ImageGenerationStreamEvent>) {
197 self.finalize(out);
198 }
199}